# trialpath/services/mcp_client.py
# 008813e — fix: correct logger variable and test mocks for enhanced search_direct
"""ClinicalTrials MCP server client wrapper."""
import asyncio
import json
import time
import httpx
import structlog
logger = structlog.get_logger("trialpath.mcp")
from trialpath.config import MCP_URL
from trialpath.models.search_anchors import SearchAnchors
from trialpath.models.trial_candidate import (
AgeRange,
EligibilityText,
TrialCandidate,
TrialLocation,
)
_MAX_RETRIES = 3
_RETRY_BACKOFF_BASE = 2.0
_CT_GOV_BASE = "https://clinicaltrials.gov/api/v2"
_CT_GOV_HEADERS = {
"User-Agent": "TrialPath/0.1 (clinical trial matching; +https://github.com/trialpath)",
}
class MCPError(Exception):
    """Error returned by the MCP server (JSON-RPC error object)."""

    def __init__(self, code: int, message: str):
        # Build the display string first, then record the raw fields so
        # callers can branch on the numeric code.
        detail = f"MCP Error {code}: {message}"
        super().__init__(detail)
        self.code = code
        self.message = message
class ClinicalTrialsMCPClient:
    """Client for ClinicalTrials MCP Server.

    Two transport paths are implemented:

    * MCP JSON-RPC over Streamable HTTP (``search``, ``get_study``,
      ``find_eligible``, ``compare_studies`` -> ``_call_tool``), and
    * direct ClinicalTrials.gov API v2 HTTP calls (``search_direct``,
      ``search_multi_variant``, ``get_study_direct``).

    Both paths retry transient connection/timeout failures up to
    ``_MAX_RETRIES`` times with exponential backoff.
    """

    def __init__(self, mcp_url: str | None = None):
        # Explicit URL wins; otherwise fall back to the configured MCP_URL.
        self.mcp_url = mcp_url or MCP_URL

    @staticmethod
    def normalize_trial(raw: dict) -> TrialCandidate:
        """Convert raw API response to TrialCandidate model.

        Accepts either a full ClinicalTrials.gov v2 record (nested under
        ``protocolSection``) or a flatter dict: every nested lookup falls
        back to a top-level key (``nctId``, ``title``, ``conditions``,
        ``phase``, ``status``) when the module dict is absent.
        """
        protocol = raw.get("protocolSection", {})
        id_module = protocol.get("identificationModule", {})
        status_module = protocol.get("statusModule", {})
        design_module = protocol.get("designModule", {})
        eligibility_module = protocol.get("eligibilityModule", {})
        conditions_module = protocol.get("conditionsModule", {})
        contacts_module = protocol.get("contactsLocationsModule", {})
        nct_id = id_module.get("nctId", raw.get("nctId", ""))
        title = id_module.get("briefTitle", raw.get("title", ""))
        conditions = conditions_module.get("conditions", raw.get("conditions", []))
        phase_list = design_module.get("phases", [])
        # Only the first listed phase is kept; multi-phase studies lose the rest.
        phase = phase_list[0] if phase_list else raw.get("phase")
        status = status_module.get("overallStatus", raw.get("status"))
        locations: list[TrialLocation] = []
        for loc in contacts_module.get("locations", []):
            locations.append(
                TrialLocation(
                    country=loc.get("country", ""),
                    city=loc.get("city"),
                )
            )
        # Age bounds: parse "18 Years"-style strings via _parse_age; build an
        # AgeRange only when at least one bound parsed to an int.
        age_range = None
        min_age_str = eligibility_module.get("minimumAge", "")
        max_age_str = eligibility_module.get("maximumAge", "")
        if min_age_str or max_age_str:
            min_age = _parse_age(min_age_str)
            max_age = _parse_age(max_age_str)
            if min_age is not None or max_age is not None:
                age_range = AgeRange(min=min_age, max=max_age)
        # Split the free-text criteria into inclusion/exclusion halves on the
        # first occurrence of the literal heading "Exclusion Criteria".
        eligibility_text = None
        criteria_text = eligibility_module.get("eligibilityCriteria", "")
        if criteria_text:
            parts = criteria_text.split("Exclusion Criteria", 1)
            inclusion = parts[0].replace("Inclusion Criteria:", "").strip()
            # NOTE(review): because the split token omits the colon, the
            # exclusion half keeps a leading ":" when the source text reads
            # "Exclusion Criteria:" — confirm downstream consumers tolerate it.
            exclusion = parts[1].strip() if len(parts) > 1 else ""
            eligibility_text = EligibilityText(
                inclusion=inclusion,
                exclusion=exclusion,
            )
        # Compact free-text blob used downstream for fingerprint matching.
        fingerprint = f"{title} {' '.join(conditions)} {phase or ''}"
        return TrialCandidate(
            nct_id=nct_id,
            title=title,
            conditions=conditions,
            phase=phase,
            status=status,
            locations=locations,
            age_range=age_range,
            fingerprint_text=fingerprint.strip(),
            eligibility_text=eligibility_text,
        )

    async def search(self, anchors: SearchAnchors) -> list[dict]:
        """Convert SearchAnchors to MCP search_studies call.

        Builds a space-joined keyword query (condition, subtype, biomarkers,
        up to 2 interventions) plus AREA[...] filter expressions for status,
        phase, and age range, then returns the ``studies`` list from the
        tool result (empty list when missing).
        """
        query_parts = [anchors.condition]
        if anchors.subtype:
            query_parts.append(anchors.subtype)
        if anchors.biomarkers:
            query_parts.extend(anchors.biomarkers)
        if anchors.interventions:
            # Cap at two interventions to keep the keyword query focused.
            query_parts.extend(anchors.interventions[:2])
        query = " ".join(query_parts)
        filters = []
        if anchors.trial_filters.recruitment_status:
            # Any of the requested statuses may match (OR within the group).
            status_filter = " OR ".join(
                f"AREA[OverallStatus]{s}" for s in anchors.trial_filters.recruitment_status
            )
            filters.append(f"({status_filter})")
        if anchors.trial_filters.phase:
            phase_filter = " OR ".join(f"AREA[Phase]{p}" for p in anchors.trial_filters.phase)
            filters.append(f"({phase_filter})")
        if anchors.age is not None:
            # Patient age must fall inside the trial's [MinimumAge, MaximumAge].
            filters.append(f"AREA[MinimumAge]RANGE[MIN, {anchors.age}]")
            filters.append(f"AREA[MaximumAge]RANGE[{anchors.age}, MAX]")
        filter_str = " AND ".join(filters) if filters else None
        params: dict = {
            "query": query,
            "pageSize": 50,
            "sort": "LastUpdateDate:desc",
        }
        if filter_str:
            params["filter"] = filter_str
        if anchors.geography:
            params["country"] = anchors.geography.country
        result = await self._call_tool("clinicaltrials_search_studies", params)
        return result.get("studies", [])

    async def get_study(self, nct_id: str) -> dict:
        """Fetch full study details by NCT ID.

        Returns the first study from the tool result, or an empty dict when
        the server returns no studies.
        """
        result = await self._call_tool(
            "clinicaltrials_get_study",
            {
                "nctIds": [nct_id],
                "summaryOnly": False,
            },
        )
        studies = result.get("studies", [])
        return studies[0] if studies else {}

    async def find_eligible(
        self,
        age: int,
        sex: str,
        conditions: list[str],
        country: str,
        max_results: int = 20,
    ) -> dict:
        """Use find_eligible_studies for demographic-based matching.

        Restricts to recruiting studies only; returns the raw tool result.
        """
        return await self._call_tool(
            "clinicaltrials_find_eligible_studies",
            {
                "age": age,
                "sex": sex,
                "conditions": conditions,
                "location": {"country": country},
                "recruitingOnly": True,
                "maxResults": max_results,
            },
        )

    async def compare_studies(self, nct_ids: list[str]) -> dict:
        """Compare 2-5 studies side by side (limit enforced server-side)."""
        return await self._call_tool(
            "clinicaltrials_compare_studies",
            {
                "nctIds": nct_ids,
                "compareFields": "all",
            },
        )

    @staticmethod
    def _parse_sse_events(text: str) -> dict:
        """Extract the last JSON-RPC data payload from an SSE event stream.

        Scans every ``data: `` line and keeps the last one that parses as
        JSON; malformed payloads are skipped. Returns {} when no line parses.
        """
        last_data: dict = {}
        for line in text.splitlines():
            if line.startswith("data: "):
                payload = line[len("data: ") :]
                try:
                    last_data = json.loads(payload)
                except json.JSONDecodeError:
                    continue
        return last_data

    async def _call_tool(self, tool_name: str, params: dict) -> dict:
        """Call an MCP tool via JSON-RPC over Streamable HTTP with retry logic.

        Handles both plain JSON and SSE (text/event-stream) responses
        for compatibility with MCP Streamable HTTP transport.

        Retries only connection/timeout errors (up to _MAX_RETRIES, with
        exponential backoff); HTTP status errors and MCPError propagate
        immediately.

        Raises:
            MCPError: empty response or a JSON-RPC ``error`` object.
            httpx.HTTPStatusError: non-2xx response.
        """
        last_error = None
        for attempt in range(_MAX_RETRIES):
            start = time.monotonic()
            try:
                # Fresh client per attempt; generous read timeout for slow tools.
                async with httpx.AsyncClient(timeout=httpx.Timeout(120.0, connect=10.0)) as client:
                    response = await client.post(
                        f"{self.mcp_url}/mcp",
                        json={
                            "jsonrpc": "2.0",
                            "method": "tools/call",
                            "params": {
                                "name": tool_name,
                                "arguments": params,
                            },
                            "id": 1,
                        },
                        headers={
                            # Advertise both content types the transport may return.
                            "Accept": "application/json, text/event-stream",
                        },
                    )
                response.raise_for_status()
                text = response.text
                content_type = response.headers.get("content-type", "")
                if "text/event-stream" in content_type:
                    data = self._parse_sse_events(text)
                else:
                    data = json.loads(text)
                if not data:
                    raise MCPError(code=-1, message="Empty response from MCP server")
                if "error" in data:
                    raise MCPError(
                        code=data["error"].get("code", -1),
                        message=data["error"].get("message", "Unknown MCP error"),
                    )
                elapsed = time.monotonic() - start
                logger.info(
                    "mcp_tool_call",
                    tool=tool_name,
                    attempt=attempt + 1,
                    duration_s=round(elapsed, 2),
                )
                result = data.get("result", {})
                return self._extract_mcp_content(result)
            except (httpx.ConnectError, httpx.TimeoutException) as e:
                last_error = e
                if attempt < _MAX_RETRIES - 1:
                    # Exponential backoff: 1s, 2s, 4s, ...
                    await asyncio.sleep(_RETRY_BACKOFF_BASE**attempt)
                    continue
                raise
        assert last_error is not None  # pragma: no cover
        raise last_error  # pragma: no cover

    @staticmethod
    def _extract_mcp_content(result: dict) -> dict:
        """Extract actual data from MCP tool result content wrapper.

        MCP wraps tool results as:
        {"content": [{"type": "text", "text": "...json..."}]}

        Returns the first text item that parses as JSON; falls back to the
        wrapper itself when no item does.
        """
        content_list = result.get("content", [])
        for item in content_list:
            if item.get("type") == "text":
                try:
                    return json.loads(item["text"])
                except (json.JSONDecodeError, KeyError):
                    continue
        return result

    async def search_direct(self, anchors: SearchAnchors) -> list[dict]:
        """Search ClinicalTrials.gov API v2 directly with retry.

        Uses structured query parameters:
        - query.cond: condition (+ subtype)
        - query.intr: first intervention (drug name)
        - query.term: eligibility keywords (ECOG, stage, biomarkers)

        Falls back to query.term-only when new fields are absent (backward
        compatible).

        The blocking ``requests`` call runs off the event loop via
        ``asyncio.to_thread``; connection/timeout errors are retried with
        exponential backoff, other HTTP errors propagate.
        """
        import requests
        # Build condition query (query.cond)
        cond_parts = [anchors.condition]
        if anchors.subtype:
            cond_parts.append(anchors.subtype)
        cond_query = " ".join(cond_parts)
        params: dict = {
            "query.cond": cond_query,
            "pageSize": 50,
        }
        # Intervention query (query.intr) — first drug name only
        if anchors.interventions:
            params["query.intr"] = anchors.interventions[0]
        # General term query (query.term) — eligibility keywords, else the
        # first token of the first biomarker as a fallback
        if anchors.eligibility_keywords:
            params["query.term"] = " ".join(anchors.eligibility_keywords)
        elif anchors.biomarkers:
            params["query.term"] = anchors.biomarkers[0].split()[0]
        if anchors.geography:
            params["query.locn"] = anchors.geography.country

        def _fetch() -> list[dict]:
            # Synchronous HTTP GET; executed in a worker thread below.
            resp = requests.get(
                f"{_CT_GOV_BASE}/studies",
                params=params,
                headers=_CT_GOV_HEADERS,
                timeout=30,
            )
            resp.raise_for_status()
            return resp.json().get("studies", [])

        last_error = None
        for attempt in range(_MAX_RETRIES):
            try:
                start = time.monotonic()
                result = await asyncio.to_thread(_fetch)
                elapsed = time.monotonic() - start
                logger.info(
                    "ct_gov_search",
                    query=cond_query,
                    params={k: v for k, v in params.items() if k.startswith("query.")},
                    result_count=len(result),
                    duration_s=round(elapsed, 2),
                    attempt=attempt + 1,
                )
                return result
            except (requests.ConnectionError, requests.Timeout) as e:
                last_error = e
                if attempt < _MAX_RETRIES - 1:
                    wait = _RETRY_BACKOFF_BASE**attempt
                    logger.warning(
                        "ct_gov_search_retry",
                        query=cond_query,
                        attempt=attempt + 1,
                        wait_s=round(wait, 1),
                        error=str(e)[:200],
                    )
                    await asyncio.sleep(wait)
                    continue
                raise
        raise last_error  # pragma: no cover

    async def search_multi_variant(self, anchors: SearchAnchors) -> list[dict]:
        """Fire multiple search variants in parallel for broader recall.

        Variants:
        1. Full query (condition + interventions + eligibility keywords)
        2. Condition-only query -- broader recall
        3. Per-biomarker queries (top 2) -- catches niche trials
        4. Per-intervention queries -- finds drug-specific trials
        5. Condition + eligibility keywords -- pre-filters by clinical features

        Results are merged and deduplicated by NCT ID; a failed variant is
        logged and skipped rather than failing the whole search.
        """
        queries = []
        # Variant 1: Full query (existing behavior)
        queries.append(self.search_direct(anchors))
        # Variant 2: Condition-only (broader recall)
        broad_anchors = SearchAnchors(condition=anchors.condition)
        if anchors.geography:
            broad_anchors.geography = anchors.geography
        queries.append(self.search_direct(broad_anchors))
        # Variant 3: Per-biomarker queries (top 2 biomarkers only)
        for biomarker in (anchors.biomarkers or [])[:2]:
            bio_anchors = SearchAnchors(
                condition=anchors.condition,
                biomarkers=[biomarker],
            )
            if anchors.geography:
                bio_anchors.geography = anchors.geography
            queries.append(self.search_direct(bio_anchors))
        # Variant 4: Per-intervention queries (drug-specific trials)
        for intervention in (anchors.interventions or [])[:3]:
            intr_anchors = SearchAnchors(
                condition=anchors.condition,
                interventions=[intervention],
            )
            if anchors.geography:
                intr_anchors.geography = anchors.geography
            queries.append(self.search_direct(intr_anchors))
        # Variant 5: Condition + eligibility keywords (clinical feature pre-filter)
        if anchors.eligibility_keywords:
            elig_anchors = SearchAnchors(
                condition=anchors.condition,
                eligibility_keywords=anchors.eligibility_keywords,
            )
            if anchors.geography:
                elig_anchors.geography = anchors.geography
            queries.append(self.search_direct(elig_anchors))
        # Fire all variants in parallel; exceptions are returned, not raised.
        results = await asyncio.gather(*queries, return_exceptions=True)
        # Merge and deduplicate by NCT ID (first occurrence wins).
        seen_nct_ids: set[str] = set()
        merged: list[dict] = []
        for result in results:
            if isinstance(result, Exception):
                logger.warning("search_variant_failed", error=str(result)[:200])
                continue
            for study in result:
                nct_id = (
                    study.get("protocolSection", {})
                    .get("identificationModule", {})
                    .get("nctId", "")
                )
                if nct_id and nct_id not in seen_nct_ids:
                    seen_nct_ids.add(nct_id)
                    merged.append(study)
        logger.info(
            "multi_variant_search_complete",
            variant_count=len(queries),
            merged_count=len(merged),
        )
        return merged

    async def get_study_direct(self, nct_id: str) -> dict:
        """Fetch study details directly from ClinicalTrials.gov API v2.

        No retry here — a single blocking GET run via ``asyncio.to_thread``;
        HTTP errors propagate as ``requests`` exceptions.
        """
        import requests

        def _fetch() -> dict:
            resp = requests.get(
                f"{_CT_GOV_BASE}/studies/{nct_id}",
                headers=_CT_GOV_HEADERS,
                timeout=30,
            )
            resp.raise_for_status()
            return resp.json()

        return await asyncio.to_thread(_fetch)
def _parse_age(age_str: str) -> int | None:
"""Parse age string like '18 Years' into integer."""
if not age_str:
return None
parts = age_str.strip().split()
try:
return int(parts[0])
except (ValueError, IndexError):
return None