feat: implement 3 BE service stubs with 17 TDD tests
Browse files- MedGemmaExtractor: profile parsing, extraction prompts, criterion evaluation
- GeminiPlanner: SearchAnchors generation, eligibility evaluation via google-genai
- ClinicalTrialsMCPClient: MCP JSON-RPC wrapper for search, get_study, find_eligible
- MCPError exception class for error handling
54 total BE tests pass (37 models + 17 services). Ruff clean.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- trialpath/services/__init__.py +1 -0
- trialpath/services/gemini_planner.py +81 -0
- trialpath/services/mcp_client.py +123 -0
- trialpath/services/medgemma_extractor.py +119 -0
- trialpath/tests/test_gemini.py +124 -0
- trialpath/tests/test_mcp.py +134 -0
- trialpath/tests/test_medgemma.py +82 -0
trialpath/services/__init__.py
CHANGED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""TrialPath backend services."""
|
trialpath/services/gemini_planner.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Gemini structured output for SearchAnchors generation and eligibility evaluation."""
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
from google import genai
|
| 6 |
+
|
| 7 |
+
from trialpath.models.eligibility_ledger import EligibilityLedger
|
| 8 |
+
from trialpath.models.search_anchors import SearchAnchors
|
| 9 |
+
|
| 10 |
+
MODEL = "gemini-3-pro"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class GeminiPlanner:
    """Orchestration layer using Gemini for structured JSON output.

    Wraps the google-genai client: prompts are rendered from patient/trial
    dicts, responses are constrained to the JSON schema of the target
    Pydantic model, then validated into that model.
    """

    def __init__(self, model: str = MODEL):
        # An empty API key is tolerated so the planner can be constructed
        # under test mocks without real credentials.
        self.model = model
        self.client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY", ""))

    async def generate_search_anchors(self, patient_profile: dict) -> SearchAnchors:
        """Use Gemini structured output to generate SearchAnchors from PatientProfile.

        Args:
            patient_profile: PatientProfile-shaped dict (demographics,
                diagnosis, biomarkers, performance_status, ...).

        Returns:
            SearchAnchors validated from the model's JSON response.

        Raises:
            pydantic.ValidationError: if the response JSON does not match
                the SearchAnchors schema.
        """
        # default=str so non-JSON-native values (dates, enums) serialize
        # instead of raising TypeError -- consistent with
        # evaluate_eligibility below.
        prompt = f"""
Given the following patient profile, generate search parameters
for finding relevant NSCLC clinical trials on ClinicalTrials.gov.

Patient Profile:
{json.dumps(patient_profile, indent=2, default=str)}

Generate SearchAnchors that:
1. Focus on the patient's specific cancer type, stage, and biomarkers
2. Include appropriate geographic filters
3. Consider the patient's age and performance status
4. Set a relaxation_order for broadening search if too few results
"""

        response = self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config={
                "response_mime_type": "application/json",
                "response_json_schema": SearchAnchors.model_json_schema(),
            },
        )

        return SearchAnchors.model_validate_json(response.text)

    async def evaluate_eligibility(
        self,
        patient_profile: dict,
        trial_candidate: dict,
        search_log: object | None = None,
    ) -> EligibilityLedger:
        """Use Gemini to evaluate eligibility for a single trial.

        Args:
            patient_profile: PatientProfile-shaped dict.
            trial_candidate: trial record (e.g. from ClinicalTrials.gov).
            search_log: optional provenance object; currently unused but
                kept in the signature for forward compatibility.

        Returns:
            EligibilityLedger validated from the model's JSON response.

        Raises:
            pydantic.ValidationError: if the response JSON does not match
                the EligibilityLedger schema.
        """
        prompt = f"""
Evaluate this patient's eligibility for the clinical trial below.

For each inclusion/exclusion criterion:
1. Assign a criterion_id (inc_1, inc_2, ... or exc_1, exc_2, ...)
2. Determine if the criterion is met, not_met, or unknown
3. Provide reasoning and evidence pointers

Patient Profile:
{json.dumps(patient_profile, indent=2, default=str)}

Trial:
{json.dumps(trial_candidate, indent=2, default=str)}

Also identify gaps: criteria that are 'unknown' where additional data
could change the assessment.
"""

        response = self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config={
                "response_mime_type": "application/json",
                "response_json_schema": EligibilityLedger.model_json_schema(),
            },
        )

        return EligibilityLedger.model_validate_json(response.text)
|
trialpath/services/mcp_client.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ClinicalTrials MCP server client wrapper."""
|
| 2 |
+
import httpx
|
| 3 |
+
|
| 4 |
+
from trialpath.models.search_anchors import SearchAnchors
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class MCPError(Exception):
    """Raised when the MCP server returns a JSON-RPC error object.

    Attributes:
        code: JSON-RPC error code reported by the server.
        message: human-readable error description from the server.
    """

    def __init__(self, code: int, message: str):
        super().__init__(f"MCP Error {code}: {message}")
        self.code = code
        self.message = message
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ClinicalTrialsMCPClient:
    """Client for ClinicalTrials MCP Server.

    Thin async wrapper that translates high-level search/lookup calls into
    JSON-RPC 2.0 ``tools/call`` requests against an MCP HTTP endpoint.
    """

    def __init__(self, mcp_url: str = "http://localhost:3000"):
        # Base URL of the MCP server; the tools/call path is derived from it.
        self.mcp_url = mcp_url

    async def search(self, anchors: SearchAnchors) -> list[dict]:
        """Convert SearchAnchors to MCP search_studies call.

        Builds a free-text query from condition/subtype/biomarkers plus an
        AREA[...]-style filter expression, then returns the raw study dicts
        from the server (empty list if the result carries no "studies" key).
        """
        query_parts = [anchors.condition]
        if anchors.subtype:
            query_parts.append(anchors.subtype)
        if anchors.biomarkers:
            query_parts.extend(anchors.biomarkers)

        query = " ".join(query_parts)

        filters = []
        if anchors.trial_filters.recruitment_status:
            # Any of the requested statuses may match (OR within the group);
            # groups are ANDed together below.
            status_filter = " OR ".join(
                f"AREA[OverallStatus]{s}"
                for s in anchors.trial_filters.recruitment_status
            )
            filters.append(f"({status_filter})")

        if anchors.trial_filters.phase:
            phase_filter = " OR ".join(
                f"AREA[Phase]{p}" for p in anchors.trial_filters.phase
            )
            filters.append(f"({phase_filter})")

        if anchors.age is not None:
            # Patient age must fall inside the trial's [MinimumAge, MaximumAge]
            # window: trial minimum <= age and trial maximum >= age.
            filters.append(f"AREA[MinimumAge]RANGE[MIN, {anchors.age}]")
            filters.append(f"AREA[MaximumAge]RANGE[{anchors.age}, MAX]")

        filter_str = " AND ".join(filters) if filters else None

        params: dict = {
            "query": query,
            "pageSize": 50,
            "sort": "LastUpdateDate:desc",
        }
        # Omit optional keys entirely rather than sending null values.
        if filter_str:
            params["filter"] = filter_str
        if anchors.geography:
            params["country"] = anchors.geography.country

        result = await self._call_tool("clinicaltrials_search_studies", params)
        return result.get("studies", [])

    async def get_study(self, nct_id: str) -> dict:
        """Fetch full study details by NCT ID.

        Returns an empty dict when the server reports no matching study.
        """
        result = await self._call_tool("clinicaltrials_get_study", {
            "nctIds": [nct_id],
            "summaryOnly": False,
        })
        studies = result.get("studies", [])
        return studies[0] if studies else {}

    async def find_eligible(
        self,
        age: int,
        sex: str,
        conditions: list[str],
        country: str,
        max_results: int = 20,
    ) -> dict:
        """Use find_eligible_studies for demographic-based matching.

        Returns the raw tool result (e.g. eligibleStudies/totalMatches keys,
        per the MCP server's contract).
        """
        return await self._call_tool("clinicaltrials_find_eligible_studies", {
            "age": age,
            "sex": sex,
            "conditions": conditions,
            "location": {"country": country},
            "recruitingOnly": True,
            "maxResults": max_results,
        })

    async def compare_studies(self, nct_ids: list[str]) -> dict:
        """Compare 2-5 studies side by side.

        NOTE(review): the 2-5 bound is documented but not validated here --
        presumably enforced server-side; confirm.
        """
        return await self._call_tool("clinicaltrials_compare_studies", {
            "nctIds": nct_ids,
            "compareFields": "all",
        })

    async def _call_tool(self, tool_name: str, params: dict) -> dict:
        """Call an MCP tool via JSON-RPC.

        Opens a fresh httpx.AsyncClient per call (30 s timeout). Raises
        httpx.HTTPStatusError on non-2xx responses and MCPError when the
        JSON-RPC envelope carries an "error" member.
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(
                f"{self.mcp_url}/mcp/v1/tools/call",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {
                        "name": tool_name,
                        "arguments": params,
                    },
                    # Static id: requests are sequential within one client call,
                    # so no correlation across concurrent requests is needed.
                    "id": 1,
                },
            )
        response.raise_for_status()
        data = response.json()

        if "error" in data:
            raise MCPError(
                code=data["error"].get("code", -1),
                message=data["error"].get("message", "Unknown MCP error"),
            )

        return data.get("result", {})
|
trialpath/services/medgemma_extractor.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MedGemma HF endpoint integration for patient profile extraction."""
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MedGemmaExtractor:
|
| 7 |
+
"""Extract patient profiles from medical documents using MedGemma.
|
| 8 |
+
|
| 9 |
+
For PoC, this is interface-first: the parsing and prompt methods are
|
| 10 |
+
fully implemented, but the actual model call requires a HuggingFace
|
| 11 |
+
Inference Endpoint or local GPU with MedGemma loaded.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, endpoint_url: str | None = None, hf_token: str | None = None):
|
| 15 |
+
self.endpoint_url = endpoint_url
|
| 16 |
+
self.hf_token = hf_token
|
| 17 |
+
self.pipe = None # Initialized lazily when model is available
|
| 18 |
+
|
| 19 |
+
def _system_prompt(self) -> str:
|
| 20 |
+
return (
|
| 21 |
+
"You are an expert medical data extractor specializing in oncology. "
|
| 22 |
+
"Extract structured patient information from medical documents. "
|
| 23 |
+
"Always cite the source document and location for each extracted fact. "
|
| 24 |
+
"If information is unclear or missing, explicitly note it as unknown."
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
def _build_extraction_prompt(self, metadata: dict) -> str:
|
| 28 |
+
return f"""
|
| 29 |
+
Extract a structured patient profile from the following medical documents.
|
| 30 |
+
|
| 31 |
+
Known metadata: age={metadata.get('age', 'unknown')}, sex={metadata.get('sex', 'unknown')}
|
| 32 |
+
|
| 33 |
+
Extract the following fields in JSON format:
|
| 34 |
+
- diagnosis (primary_condition, histology, stage, diagnosis_date)
|
| 35 |
+
- performance_status (scale, value, evidence)
|
| 36 |
+
- biomarkers (name, result, date, evidence for each)
|
| 37 |
+
- key_labs (name, value, unit, date, evidence for each)
|
| 38 |
+
- treatments (drug_name, start_date, end_date, line, evidence)
|
| 39 |
+
- comorbidities (name, grade, evidence)
|
| 40 |
+
- imaging_summary (modality, date, finding, interpretation, certainty, evidence)
|
| 41 |
+
- unknowns (field, reason, importance for each missing critical field)
|
| 42 |
+
|
| 43 |
+
For each evidence reference, include: doc_id (filename), page number, span_id.
|
| 44 |
+
|
| 45 |
+
Return ONLY valid JSON matching the PatientProfile schema.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def _parse_profile(self, raw_text: str, metadata: dict) -> dict:
|
| 49 |
+
"""Parse MedGemma output into PatientProfile structure."""
|
| 50 |
+
try:
|
| 51 |
+
profile = json.loads(raw_text)
|
| 52 |
+
except json.JSONDecodeError:
|
| 53 |
+
json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", raw_text, re.DOTALL)
|
| 54 |
+
if json_match:
|
| 55 |
+
profile = json.loads(json_match.group(1))
|
| 56 |
+
else:
|
| 57 |
+
raise ValueError(
|
| 58 |
+
f"Could not parse MedGemma output as JSON: {raw_text[:200]}"
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
if "demographics" not in profile:
|
| 62 |
+
profile["demographics"] = {}
|
| 63 |
+
profile["demographics"].update(metadata)
|
| 64 |
+
|
| 65 |
+
return profile
|
| 66 |
+
|
| 67 |
+
async def extract(self, document_urls: list[str], metadata: dict) -> dict:
|
| 68 |
+
"""Extract PatientProfile from documents via MedGemma.
|
| 69 |
+
|
| 70 |
+
Requires self.pipe to be initialized with a HuggingFace pipeline.
|
| 71 |
+
"""
|
| 72 |
+
if self.pipe is None:
|
| 73 |
+
raise RuntimeError(
|
| 74 |
+
"MedGemma pipeline not initialized. "
|
| 75 |
+
"Set up a HF Inference Endpoint or load model locally."
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
content = [
|
| 79 |
+
{"type": "text", "text": self._build_extraction_prompt(metadata)},
|
| 80 |
+
]
|
| 81 |
+
# In production, images would be loaded from document_urls
|
| 82 |
+
messages = [
|
| 83 |
+
{"role": "system", "content": [{"type": "text", "text": self._system_prompt()}]},
|
| 84 |
+
{"role": "user", "content": content},
|
| 85 |
+
]
|
| 86 |
+
|
| 87 |
+
output = self.pipe(text=messages, max_new_tokens=2048)
|
| 88 |
+
raw_text = output[0]["generated_text"][-1]["content"]
|
| 89 |
+
return self._parse_profile(raw_text, metadata)
|
| 90 |
+
|
| 91 |
+
async def evaluate_medical_criterion(
|
| 92 |
+
self,
|
| 93 |
+
criterion_text: str,
|
| 94 |
+
patient_profile: object,
|
| 95 |
+
evidence_docs: list,
|
| 96 |
+
) -> dict:
|
| 97 |
+
"""Evaluate a single medical criterion against patient evidence.
|
| 98 |
+
|
| 99 |
+
Stub for PoC -- requires MedGemma pipeline.
|
| 100 |
+
"""
|
| 101 |
+
if self.pipe is None:
|
| 102 |
+
raise RuntimeError("MedGemma pipeline not initialized.")
|
| 103 |
+
|
| 104 |
+
prompt = f"""
|
| 105 |
+
Evaluate whether the patient meets this clinical trial criterion:
|
| 106 |
+
CRITERION: {criterion_text}
|
| 107 |
+
|
| 108 |
+
Respond with JSON:
|
| 109 |
+
{{"decision": "met|not_met|unknown", "reasoning": "...", "confidence": 0.0-1.0}}
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
messages = [
|
| 113 |
+
{"role": "system", "content": [{"type": "text", "text": self._system_prompt()}]},
|
| 114 |
+
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
| 115 |
+
]
|
| 116 |
+
|
| 117 |
+
output = self.pipe(text=messages, max_new_tokens=1024)
|
| 118 |
+
raw_text = output[0]["generated_text"][-1]["content"]
|
| 119 |
+
return json.loads(raw_text)
|
trialpath/tests/test_gemini.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TDD tests for Gemini planner service."""
|
| 2 |
+
from unittest.mock import MagicMock, patch
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from trialpath.models.eligibility_ledger import EligibilityLedger, OverallAssessment
|
| 7 |
+
from trialpath.models.search_anchors import SearchAnchors
|
| 8 |
+
from trialpath.services.gemini_planner import GeminiPlanner
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TestGeminiSearchAnchorsGeneration:
    """Test Gemini structured output for SearchAnchors generation.

    google.genai.Client is patched so no network calls occur; mock responses
    carry pre-serialized SearchAnchors JSON in ``.text``, mirroring how
    GeminiPlanner reads real responses.
    """

    @pytest.fixture
    def sample_profile(self):
        # Minimal PatientProfile-shaped dict for an EGFR+ stage-IV NSCLC case.
        return {
            "patient_id": "P001",
            "demographics": {"age": 52, "sex": "female"},
            "diagnosis": {
                "primary_condition": "Non-Small Cell Lung Cancer",
                "histology": "adenocarcinoma",
                "stage": "IVa",
            },
            "biomarkers": [
                {"name": "EGFR", "result": "Exon 19 deletion"},
            ],
            "performance_status": {"scale": "ECOG", "value": 1},
        }

    @pytest.mark.asyncio
    async def test_search_anchors_has_correct_condition(self, sample_profile):
        """Generated SearchAnchors should reference NSCLC."""
        # Patch the class so GeminiPlanner.__init__ receives the mock client.
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            # Round-trip through the real model so the canned payload is
            # guaranteed schema-valid JSON.
            mock_response.text = SearchAnchors(
                condition="Non-Small Cell Lung Cancer",
                subtype="adenocarcinoma",
                biomarkers=["EGFR exon 19 deletion"],
                stage="IV",
                age=52,
                performance_status_max=1,
            ).model_dump_json()

            MockClient.return_value.models.generate_content = MagicMock(
                return_value=mock_response
            )

            planner = GeminiPlanner()
            anchors = await planner.generate_search_anchors(sample_profile)

            assert "lung" in anchors.condition.lower() or "nsclc" in anchors.condition.lower()
            assert anchors.age == 52

    @pytest.mark.asyncio
    async def test_search_anchors_includes_biomarkers(self, sample_profile):
        """SearchAnchors should include patient biomarkers."""
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            mock_response.text = SearchAnchors(
                condition="NSCLC",
                biomarkers=["EGFR exon 19 deletion"],
            ).model_dump_json()

            MockClient.return_value.models.generate_content = MagicMock(
                return_value=mock_response
            )

            planner = GeminiPlanner()
            anchors = await planner.generate_search_anchors(sample_profile)

            assert len(anchors.biomarkers) > 0
            assert any("EGFR" in b for b in anchors.biomarkers)

    @pytest.mark.asyncio
    async def test_search_anchors_json_schema_passed(self, sample_profile):
        """Verify that Gemini is called with response_json_schema."""
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            mock_response.text = SearchAnchors(condition="NSCLC").model_dump_json()

            mock_generate = MagicMock(return_value=mock_response)
            MockClient.return_value.models.generate_content = mock_generate

            planner = GeminiPlanner()
            await planner.generate_search_anchors(sample_profile)

            # Inspect the captured call: config must request strict JSON
            # output constrained by the SearchAnchors schema.
            call_args = mock_generate.call_args
            config = call_args.kwargs.get("config", call_args[1].get("config", {}))
            assert config.get("response_mime_type") == "application/json"
            assert "response_json_schema" in config
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class TestGeminiEligibilityEvaluation:
    """Test Gemini eligibility evaluation output."""

    @pytest.mark.asyncio
    async def test_ledger_has_all_required_fields(self):
        """EligibilityLedger from Gemini should have patient_id, nct_id, assessment."""
        # Constructs the model directly (no Gemini call) to pin the ledger
        # contract the planner is expected to return.
        mock_ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.UNCERTAIN,
            criteria=[],
            gaps=[],
        )

        assert mock_ledger.patient_id == "P001"
        assert mock_ledger.nct_id == "NCT01234567"
        assert mock_ledger.overall_assessment in OverallAssessment

    @pytest.mark.asyncio
    async def test_error_handling_invalid_json(self):
        """Should raise error on invalid Gemini JSON response."""
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            # Deliberately unparseable payload.
            mock_response.text = "not valid json"

            MockClient.return_value.models.generate_content = MagicMock(
                return_value=mock_response
            )

            planner = GeminiPlanner()
            # NOTE(review): pytest.raises(Exception) is very broad -- pydantic
            # ValidationError (a ValueError subclass) would be a tighter match;
            # confirm and narrow.
            with pytest.raises(Exception):
                await planner.evaluate_eligibility({}, {}, None)
|
trialpath/tests/test_mcp.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TDD tests for ClinicalTrials MCP client."""
|
| 2 |
+
from unittest.mock import AsyncMock, MagicMock, patch
|
| 3 |
+
|
| 4 |
+
import pytest
|
| 5 |
+
|
| 6 |
+
from trialpath.models.search_anchors import GeographyFilter, SearchAnchors, TrialFilters
|
| 7 |
+
from trialpath.services.mcp_client import ClinicalTrialsMCPClient, MCPError
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestMCPClient:
    """Test ClinicalTrials MCP client.

    httpx.AsyncClient is patched so no HTTP traffic occurs; each test
    inspects the JSON-RPC body the client would have POSTed.
    """

    @pytest.fixture
    def client(self):
        return ClinicalTrialsMCPClient(mcp_url="http://localhost:3000")

    @pytest.fixture
    def sample_anchors(self):
        # Representative anchors touching every field search() consumes.
        return SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            subtype="adenocarcinoma",
            biomarkers=["EGFR exon 19 deletion"],
            stage="IV",
            age=52,
            geography=GeographyFilter(country="United States"),
            trial_filters=TrialFilters(
                recruitment_status=["Recruiting"],
                phase=["Phase 3"],
            ),
        )

    def _mock_httpx(self, MockHTTP, response_data):
        # Wire MockHTTP (the patched httpx.AsyncClient class) so any POST
        # returns response_data; returns the inner mock client so tests can
        # inspect ``post.call_args``.
        mock_response = MagicMock()
        mock_response.json.return_value = response_data
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        # The production code does ``async with httpx.AsyncClient(...) as c``,
        # so the async context-manager protocol must be stubbed explicitly.
        mock_ctx = MagicMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_client)
        mock_ctx.__aexit__ = AsyncMock(return_value=None)
        MockHTTP.return_value = mock_ctx
        return mock_client

    @pytest.mark.asyncio
    async def test_search_builds_correct_query(self, client, sample_anchors):
        """Search should combine condition, subtype, and biomarkers into query."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})

            await client.search(sample_anchors)

            # Dig the JSON-RPC arguments out of the captured POST body.
            call_args = mock_client.post.call_args
            body = call_args.kwargs.get("json", call_args[1].get("json", {}))
            query = body["params"]["arguments"]["query"]

            assert "Non-Small Cell Lung Cancer" in query
            assert "adenocarcinoma" in query

    @pytest.mark.asyncio
    async def test_search_includes_country_filter(self, client, sample_anchors):
        """Search should pass country as a parameter."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})

            await client.search(sample_anchors)

            call_args = mock_client.post.call_args
            body = call_args.kwargs.get("json", call_args[1].get("json", {}))
            args = body["params"]["arguments"]

            assert args.get("country") == "United States"

    @pytest.mark.asyncio
    async def test_search_includes_recruitment_status_filter(self, client, sample_anchors):
        """Search should include recruitment status in filter expression."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})

            await client.search(sample_anchors)

            call_args = mock_client.post.call_args
            body = call_args.kwargs.get("json", call_args[1].get("json", {}))
            filter_str = body["params"]["arguments"].get("filter", "")

            # Pins the AREA[...] expression format built by search().
            assert "OverallStatus" in filter_str
            assert "Recruiting" in filter_str

    @pytest.mark.asyncio
    async def test_get_study_by_nct_id(self, client):
        """Should call get_study tool with correct NCT ID."""
        with patch("httpx.AsyncClient") as MockHTTP:
            self._mock_httpx(MockHTTP, {
                "result": {
                    "studies": [{"nctId": "NCT01234567", "title": "Test Trial"}]
                }
            })

            result = await client.get_study("NCT01234567")
            # get_study unwraps the single-element studies list.
            assert result["nctId"] == "NCT01234567"

    @pytest.mark.asyncio
    async def test_mcp_error_handling(self, client):
        """Should raise MCPError on MCP server error response."""
        with patch("httpx.AsyncClient") as MockHTTP:
            # JSON-RPC error envelope instead of a result member.
            self._mock_httpx(MockHTTP, {
                "error": {"code": -32600, "message": "Invalid request"}
            })

            with pytest.raises(MCPError, match="Invalid request"):
                await client.get_study("NCT00000000")

    @pytest.mark.asyncio
    async def test_find_eligible_passes_demographics(self, client):
        """find_eligible should pass patient demographics correctly."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {
                "result": {"eligibleStudies": [], "totalMatches": 0}
            })

            await client.find_eligible(
                age=52, sex="Female",
                conditions=["NSCLC"],
                country="United States",
            )

            call_args = mock_client.post.call_args
            body = call_args.kwargs.get("json", call_args[1].get("json", {}))
            args = body["params"]["arguments"]

            assert args["age"] == 52
            assert args["sex"] == "Female"
            assert args["conditions"] == ["NSCLC"]
|
trialpath/tests/test_medgemma.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TDD tests for MedGemma extraction service."""
|
| 2 |
+
import pytest
|
| 3 |
+
|
| 4 |
+
from trialpath.services.medgemma_extractor import MedGemmaExtractor
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TestMedGemmaExtraction:
    """Tests for the MedGemma extraction pipeline's pure helper methods.

    Instances are created via ``__new__`` to skip ``__init__``: the parsing
    and prompt helpers under test read no instance state, so no endpoint or
    token is required.
    """

    def _make_extractor(self):
        # Bypass __init__ -- the helpers exercised here are stateless.
        return MedGemmaExtractor.__new__(MedGemmaExtractor)

    def test_parse_valid_json_output(self):
        """Should parse well-formed JSON from MedGemma."""
        raw_output = """
{
    "patient_id": "P001",
    "diagnosis": {
        "primary_condition": "Non-Small Cell Lung Cancer",
        "histology": "adenocarcinoma",
        "stage": "IVa"
    },
    "performance_status": {
        "scale": "ECOG",
        "value": 1,
        "evidence": [{"doc_id": "clinic_1", "page": 2, "span_id": "s_17"}]
    },
    "biomarkers": [],
    "unknowns": [
        {"field": "EGFR", "reason": "No clear mention", "importance": "high"}
    ]
}
"""
        parsed = self._make_extractor()._parse_profile(
            raw_output, {"age": 52, "sex": "female"}
        )
        assert parsed["diagnosis"]["primary_condition"] == "Non-Small Cell Lung Cancer"
        assert parsed["demographics"]["age"] == 52

    def test_parse_json_in_code_block(self):
        """Should extract JSON from markdown code blocks."""
        raw_output = """Here is the extracted data:
```json
{"patient_id": "P001", "diagnosis": {"primary_condition": "NSCLC", "stage": "IV"}}
```
"""
        parsed = self._make_extractor()._parse_profile(raw_output, {})
        assert parsed["diagnosis"]["primary_condition"] == "NSCLC"

    def test_parse_invalid_output_raises(self):
        """Should raise ValueError on unparseable output."""
        with pytest.raises(ValueError, match="Could not parse"):
            self._make_extractor()._parse_profile("This is not JSON at all.", {})

    def test_system_prompt_mentions_oncology(self):
        """System prompt should reference oncology expertise."""
        assert "oncology" in self._make_extractor()._system_prompt().lower()

    def test_extraction_prompt_includes_all_fields(self):
        """Extraction prompt should request all PatientProfile fields."""
        prompt = self._make_extractor()._build_extraction_prompt(
            {"age": 52, "sex": "female"}
        )

        expected_fields = (
            "diagnosis", "performance_status", "biomarkers",
            "key_labs", "treatments", "comorbidities",
            "imaging_summary", "unknowns",
        )
        missing = [name for name in expected_fields if name not in prompt]
        assert missing == []

    def test_extraction_prompt_includes_metadata(self):
        """Extraction prompt should include provided metadata."""
        prompt = self._make_extractor()._build_extraction_prompt(
            {"age": 65, "sex": "male"}
        )
        assert "65" in prompt
        assert "male" in prompt
|