| """ClinicalTrials MCP server client wrapper.""" |
|
|
| import asyncio |
| import json |
| import time |
|
|
| import httpx |
| import structlog |
|
|
| logger = structlog.get_logger("trialpath.mcp") |
|
|
| from trialpath.config import MCP_URL |
| from trialpath.models.search_anchors import SearchAnchors |
| from trialpath.models.trial_candidate import ( |
| AgeRange, |
| EligibilityText, |
| TrialCandidate, |
| TrialLocation, |
| ) |
|
|
| _MAX_RETRIES = 3 |
| _RETRY_BACKOFF_BASE = 2.0 |
| _CT_GOV_BASE = "https://clinicaltrials.gov/api/v2" |
| _CT_GOV_HEADERS = { |
| "User-Agent": "TrialPath/0.1 (clinical trial matching; +https://github.com/trialpath)", |
| } |
|
|
|
|
class MCPError(Exception):
    """Exception raised when the MCP server reports an error payload."""

    def __init__(self, code: int, message: str):
        # Human-readable summary becomes the exception's args/str().
        super().__init__(f"MCP Error {code}: {message}")
        # Keep the raw JSON-RPC fields available to callers.
        self.code = code
        self.message = message
|
|
|
|
class ClinicalTrialsMCPClient:
    """Client for the ClinicalTrials MCP Server.

    Provides two access paths:

    * JSON-RPC tool calls over MCP Streamable HTTP (``search``,
      ``get_study``, ``find_eligible``, ``compare_studies``).
    * Direct ClinicalTrials.gov API v2 requests (``search_direct``,
      ``search_multi_variant``, ``get_study_direct``) that do not
      require a running MCP server.
    """

    def __init__(self, mcp_url: str | None = None):
        """Create a client.

        Args:
            mcp_url: Base URL of the MCP server; falls back to the
                configured ``MCP_URL`` when omitted or falsy.
        """
        self.mcp_url = mcp_url or MCP_URL

    @staticmethod
    def normalize_trial(raw: dict) -> TrialCandidate:
        """Convert a raw API study record into a :class:`TrialCandidate`.

        Accepts both full API v2 records (nested under ``protocolSection``)
        and flattened summaries: each field falls back to a top-level key
        (``nctId``, ``title``, ``conditions``, ``phase``, ``status``) when
        the nested module is absent.
        """
        protocol = raw.get("protocolSection", {})
        id_module = protocol.get("identificationModule", {})
        status_module = protocol.get("statusModule", {})
        design_module = protocol.get("designModule", {})
        eligibility_module = protocol.get("eligibilityModule", {})
        conditions_module = protocol.get("conditionsModule", {})
        contacts_module = protocol.get("contactsLocationsModule", {})

        nct_id = id_module.get("nctId", raw.get("nctId", ""))
        title = id_module.get("briefTitle", raw.get("title", ""))

        conditions = conditions_module.get("conditions", raw.get("conditions", []))

        # API v2 reports phases as a list; keep only the first entry.
        phase_list = design_module.get("phases", [])
        phase = phase_list[0] if phase_list else raw.get("phase")

        status = status_module.get("overallStatus", raw.get("status"))

        locations: list[TrialLocation] = []
        for loc in contacts_module.get("locations", []):
            locations.append(
                TrialLocation(
                    country=loc.get("country", ""),
                    city=loc.get("city"),
                )
            )

        # Ages arrive as strings like "18 Years"; build an AgeRange only
        # when at least one bound parses to an integer.
        age_range = None
        min_age_str = eligibility_module.get("minimumAge", "")
        max_age_str = eligibility_module.get("maximumAge", "")
        if min_age_str or max_age_str:
            min_age = _parse_age(min_age_str)
            max_age = _parse_age(max_age_str)
            if min_age is not None or max_age is not None:
                age_range = AgeRange(min=min_age, max=max_age)

        eligibility_text = None
        criteria_text = eligibility_module.get("eligibilityCriteria", "")
        if criteria_text:
            # Criteria text conventionally reads
            # "Inclusion Criteria: ... Exclusion Criteria: ...".
            parts = criteria_text.split("Exclusion Criteria", 1)
            inclusion = parts[0].replace("Inclusion Criteria:", "").strip()
            # BUGFIX: strip the leading ":" that splitting on the
            # "Exclusion Criteria" heading leaves at the start of parts[1];
            # previously every exclusion text began with a stray colon.
            exclusion = parts[1].lstrip(":").strip() if len(parts) > 1 else ""
            eligibility_text = EligibilityText(
                inclusion=inclusion,
                exclusion=exclusion,
            )

        # Free-text blob used downstream for similarity matching.
        fingerprint = f"{title} {' '.join(conditions)} {phase or ''}"

        return TrialCandidate(
            nct_id=nct_id,
            title=title,
            conditions=conditions,
            phase=phase,
            status=status,
            locations=locations,
            age_range=age_range,
            fingerprint_text=fingerprint.strip(),
            eligibility_text=eligibility_text,
        )

    async def search(self, anchors: SearchAnchors) -> list[dict]:
        """Convert SearchAnchors to an MCP ``search_studies`` call.

        Builds a free-text query from condition/subtype/biomarkers plus up
        to two interventions, and an ``AREA[...]`` filter expression for
        recruitment status, phase, and the patient's age window.
        """
        query_parts = [anchors.condition]
        if anchors.subtype:
            query_parts.append(anchors.subtype)
        if anchors.biomarkers:
            query_parts.extend(anchors.biomarkers)
        if anchors.interventions:
            # Cap interventions at two to keep the query from over-narrowing.
            query_parts.extend(anchors.interventions[:2])

        query = " ".join(query_parts)

        filters = []
        if anchors.trial_filters.recruitment_status:
            status_filter = " OR ".join(
                f"AREA[OverallStatus]{s}" for s in anchors.trial_filters.recruitment_status
            )
            filters.append(f"({status_filter})")

        if anchors.trial_filters.phase:
            phase_filter = " OR ".join(f"AREA[Phase]{p}" for p in anchors.trial_filters.phase)
            filters.append(f"({phase_filter})")

        if anchors.age is not None:
            # The patient's age must fall inside the trial's [min, max] window.
            filters.append(f"AREA[MinimumAge]RANGE[MIN, {anchors.age}]")
            filters.append(f"AREA[MaximumAge]RANGE[{anchors.age}, MAX]")

        filter_str = " AND ".join(filters) if filters else None

        params: dict = {
            "query": query,
            "pageSize": 50,
            "sort": "LastUpdateDate:desc",
        }
        if filter_str:
            params["filter"] = filter_str
        if anchors.geography:
            params["country"] = anchors.geography.country

        result = await self._call_tool("clinicaltrials_search_studies", params)
        return result.get("studies", [])

    async def get_study(self, nct_id: str) -> dict:
        """Fetch full study details by NCT ID; returns {} when not found."""
        result = await self._call_tool(
            "clinicaltrials_get_study",
            {
                "nctIds": [nct_id],
                "summaryOnly": False,
            },
        )
        studies = result.get("studies", [])
        return studies[0] if studies else {}

    async def find_eligible(
        self,
        age: int,
        sex: str,
        conditions: list[str],
        country: str,
        max_results: int = 20,
    ) -> dict:
        """Use ``find_eligible_studies`` for demographic-based matching.

        Args:
            age: Patient age in years.
            sex: Patient sex string as expected by the MCP tool.
            conditions: Condition names to match against.
            country: Country used as the location constraint.
            max_results: Upper bound on returned studies.
        """
        return await self._call_tool(
            "clinicaltrials_find_eligible_studies",
            {
                "age": age,
                "sex": sex,
                "conditions": conditions,
                "location": {"country": country},
                "recruitingOnly": True,
                "maxResults": max_results,
            },
        )

    async def compare_studies(self, nct_ids: list[str]) -> dict:
        """Compare 2-5 studies side by side via the MCP compare tool."""
        return await self._call_tool(
            "clinicaltrials_compare_studies",
            {
                "nctIds": nct_ids,
                "compareFields": "all",
            },
        )

    @staticmethod
    def _parse_sse_events(text: str) -> dict:
        """Extract the last JSON-RPC data payload from an SSE event stream.

        ``data:`` lines that do not parse as JSON (e.g. keep-alives) are
        skipped; returns an empty dict when no line carries valid JSON.
        """
        last_data: dict = {}
        for line in text.splitlines():
            if line.startswith("data: "):
                payload = line[len("data: ") :]
                try:
                    last_data = json.loads(payload)
                except json.JSONDecodeError:
                    continue
        return last_data

    async def _call_tool(self, tool_name: str, params: dict) -> dict:
        """Call an MCP tool via JSON-RPC over Streamable HTTP with retry logic.

        Handles both plain JSON and SSE (text/event-stream) responses
        for compatibility with MCP Streamable HTTP transport.

        Raises:
            MCPError: On an empty response or a JSON-RPC ``error`` payload
                (not retried).
            httpx.ConnectError | httpx.TimeoutException: When all
                ``_MAX_RETRIES`` attempts fail to reach the server.
            httpx.HTTPStatusError: On a non-2xx HTTP response (not retried).
        """
        last_error = None
        for attempt in range(_MAX_RETRIES):
            start = time.monotonic()
            try:
                async with httpx.AsyncClient(timeout=httpx.Timeout(120.0, connect=10.0)) as client:
                    response = await client.post(
                        f"{self.mcp_url}/mcp",
                        json={
                            "jsonrpc": "2.0",
                            "method": "tools/call",
                            "params": {
                                "name": tool_name,
                                "arguments": params,
                            },
                            "id": 1,
                        },
                        headers={
                            # Streamable HTTP servers may answer with plain
                            # JSON or an SSE stream; advertise both.
                            "Accept": "application/json, text/event-stream",
                        },
                    )
                    response.raise_for_status()
                    text = response.text

                    content_type = response.headers.get("content-type", "")
                    if "text/event-stream" in content_type:
                        data = self._parse_sse_events(text)
                    else:
                        data = json.loads(text)

                    if not data:
                        raise MCPError(code=-1, message="Empty response from MCP server")

                    if "error" in data:
                        raise MCPError(
                            code=data["error"].get("code", -1),
                            message=data["error"].get("message", "Unknown MCP error"),
                        )

                    elapsed = time.monotonic() - start
                    logger.info(
                        "mcp_tool_call",
                        tool=tool_name,
                        attempt=attempt + 1,
                        duration_s=round(elapsed, 2),
                    )

                    result = data.get("result", {})
                    return self._extract_mcp_content(result)
            except (httpx.ConnectError, httpx.TimeoutException) as e:
                # Only transport-level failures are retried; MCPError and
                # HTTP status errors propagate immediately.
                last_error = e
                if attempt < _MAX_RETRIES - 1:
                    await asyncio.sleep(_RETRY_BACKOFF_BASE**attempt)
                    continue
                raise
        # Defensive: the loop always returns or re-raises on the final attempt.
        assert last_error is not None
        raise last_error

    @staticmethod
    def _extract_mcp_content(result: dict) -> dict:
        """Extract actual data from MCP tool result content wrapper.

        MCP wraps tool results as:
            {"content": [{"type": "text", "text": "...json..."}]}

        Returns the first text item that parses as JSON; falls back to the
        raw ``result`` when no item does.
        """
        content_list = result.get("content", [])
        for item in content_list:
            if item.get("type") == "text":
                try:
                    return json.loads(item["text"])
                except (json.JSONDecodeError, KeyError):
                    continue
        return result

    async def search_direct(self, anchors: SearchAnchors) -> list[dict]:
        """Search ClinicalTrials.gov API v2 directly with retry.

        Uses structured query parameters:
        - query.cond: condition (+ subtype)
        - query.intr: first intervention (drug name)
        - query.term: eligibility keywords (ECOG, stage, biomarkers)

        Falls back to query.term-only when new fields are absent (backward
        compatible).

        Raises:
            requests.ConnectionError | requests.Timeout: When all retry
                attempts fail.
            requests.HTTPError: On a non-2xx response (not retried).
        """
        import requests

        cond_parts = [anchors.condition]
        if anchors.subtype:
            cond_parts.append(anchors.subtype)
        cond_query = " ".join(cond_parts)

        params: dict = {
            "query.cond": cond_query,
            "pageSize": 50,
        }

        if anchors.interventions:
            params["query.intr"] = anchors.interventions[0]

        if anchors.eligibility_keywords:
            params["query.term"] = " ".join(anchors.eligibility_keywords)
        elif anchors.biomarkers:
            # Only the first token of the first biomarker (e.g. the gene
            # name) so the term query stays broad.
            params["query.term"] = anchors.biomarkers[0].split()[0]

        if anchors.geography:
            params["query.locn"] = anchors.geography.country

        def _fetch() -> list[dict]:
            # Blocking requests call; executed in a worker thread below.
            resp = requests.get(
                f"{_CT_GOV_BASE}/studies",
                params=params,
                headers=_CT_GOV_HEADERS,
                timeout=30,
            )
            resp.raise_for_status()
            return resp.json().get("studies", [])

        last_error = None
        for attempt in range(_MAX_RETRIES):
            try:
                start = time.monotonic()
                result = await asyncio.to_thread(_fetch)
                elapsed = time.monotonic() - start

                logger.info(
                    "ct_gov_search",
                    query=cond_query,
                    params={k: v for k, v in params.items() if k.startswith("query.")},
                    result_count=len(result),
                    duration_s=round(elapsed, 2),
                    attempt=attempt + 1,
                )
                return result
            except (requests.ConnectionError, requests.Timeout) as e:
                last_error = e
                if attempt < _MAX_RETRIES - 1:
                    wait = _RETRY_BACKOFF_BASE**attempt
                    logger.warning(
                        "ct_gov_search_retry",
                        query=cond_query,
                        attempt=attempt + 1,
                        wait_s=round(wait, 1),
                        error=str(e)[:200],
                    )
                    await asyncio.sleep(wait)
                    continue
                raise
        # Defensive (mirrors _call_tool): the loop always returns or
        # re-raises on the final attempt; the assert guards against
        # raising None if that invariant is ever broken.
        assert last_error is not None
        raise last_error

    async def search_multi_variant(self, anchors: SearchAnchors) -> list[dict]:
        """Fire multiple search variants in parallel for broader recall.

        Variants:
        1. Full query (condition + interventions + eligibility keywords)
        2. Condition-only query -- broader recall
        3. Per-biomarker queries (top 2) -- catches niche trials
        4. Per-intervention queries (top 3) -- finds drug-specific trials
        5. Condition + eligibility keywords -- pre-filters by clinical features

        Results are merged and deduplicated by NCT ID. A failing variant is
        logged and skipped rather than failing the whole search.
        """
        queries = []

        # Variant 1: full query with every anchor populated.
        queries.append(self.search_direct(anchors))

        # Variant 2: condition only (plus geography) for broad recall.
        broad_anchors = SearchAnchors(condition=anchors.condition)
        if anchors.geography:
            broad_anchors.geography = anchors.geography
        queries.append(self.search_direct(broad_anchors))

        # Variant 3: one query per biomarker (top 2).
        for biomarker in (anchors.biomarkers or [])[:2]:
            bio_anchors = SearchAnchors(
                condition=anchors.condition,
                biomarkers=[biomarker],
            )
            if anchors.geography:
                bio_anchors.geography = anchors.geography
            queries.append(self.search_direct(bio_anchors))

        # Variant 4: one query per intervention (top 3).
        for intervention in (anchors.interventions or [])[:3]:
            intr_anchors = SearchAnchors(
                condition=anchors.condition,
                interventions=[intervention],
            )
            if anchors.geography:
                intr_anchors.geography = anchors.geography
            queries.append(self.search_direct(intr_anchors))

        # Variant 5: condition + eligibility keywords.
        if anchors.eligibility_keywords:
            elig_anchors = SearchAnchors(
                condition=anchors.condition,
                eligibility_keywords=anchors.eligibility_keywords,
            )
            if anchors.geography:
                elig_anchors.geography = anchors.geography
            queries.append(self.search_direct(elig_anchors))

        # Run all variants concurrently; return_exceptions keeps one failed
        # variant from cancelling its siblings.
        results = await asyncio.gather(*queries, return_exceptions=True)

        # Merge, deduplicating by NCT ID (first occurrence wins).
        seen_nct_ids: set[str] = set()
        merged: list[dict] = []
        for result in results:
            if isinstance(result, Exception):
                logger.warning("search_variant_failed", error=str(result)[:200])
                continue
            for study in result:
                nct_id = (
                    study.get("protocolSection", {})
                    .get("identificationModule", {})
                    .get("nctId", "")
                )
                if nct_id and nct_id not in seen_nct_ids:
                    seen_nct_ids.add(nct_id)
                    merged.append(study)

        logger.info(
            "multi_variant_search_complete",
            variant_count=len(queries),
            merged_count=len(merged),
        )
        return merged

    async def get_study_direct(self, nct_id: str) -> dict:
        """Fetch study details directly from ClinicalTrials.gov API v2."""
        import requests

        def _fetch() -> dict:
            # Blocking requests call; executed in a worker thread below.
            resp = requests.get(
                f"{_CT_GOV_BASE}/studies/{nct_id}",
                headers=_CT_GOV_HEADERS,
                timeout=30,
            )
            resp.raise_for_status()
            return resp.json()

        return await asyncio.to_thread(_fetch)
|
|
|
|
| def _parse_age(age_str: str) -> int | None: |
| """Parse age string like '18 Years' into integer.""" |
| if not age_str: |
| return None |
| parts = age_str.strip().split() |
| try: |
| return int(parts[0]) |
| except (ValueError, IndexError): |
| return None |
|
|