yakilee Claude Opus 4.6 committed on
Commit
a4c0e5c
·
1 Parent(s): 1943883

feat: implement 3 BE service stubs with 17 TDD tests

Browse files

- MedGemmaExtractor: profile parsing, extraction prompts, criterion evaluation
- GeminiPlanner: SearchAnchors generation, eligibility evaluation via google-genai
- ClinicalTrialsMCPClient: MCP JSON-RPC wrapper for search, get_study, find_eligible
- MCPError exception class for error handling

54 total BE tests pass (37 models + 17 services). Ruff clean.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

trialpath/services/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ """TrialPath backend services."""
trialpath/services/gemini_planner.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Gemini structured output for SearchAnchors generation and eligibility evaluation."""
2
+ import json
3
+ import os
4
+
5
+ from google import genai
6
+
7
+ from trialpath.models.eligibility_ledger import EligibilityLedger
8
+ from trialpath.models.search_anchors import SearchAnchors
9
+
10
MODEL = "gemini-3-pro"


class GeminiPlanner:
    """Orchestration layer using Gemini for structured output.

    Both public methods build a prompt, request JSON constrained by a
    pydantic-derived schema, and validate the response text back into
    the corresponding model instance.
    """

    def __init__(self, model: str = MODEL):
        self.model = model
        # Empty-string fallback keeps construction from raising when the
        # key is absent (e.g. in unit tests that patch genai.Client).
        self.client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY", ""))

    def _structured_call(self, prompt: str, schema: dict) -> str:
        """Run one schema-constrained generation and return the raw JSON text."""
        response = self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config={
                "response_mime_type": "application/json",
                "response_json_schema": schema,
            },
        )
        return response.text

    async def generate_search_anchors(self, patient_profile: dict) -> SearchAnchors:
        """Use Gemini structured output to generate SearchAnchors from PatientProfile."""
        prompt = f"""
        Given the following patient profile, generate search parameters
        for finding relevant NSCLC clinical trials on ClinicalTrials.gov.

        Patient Profile:
        {json.dumps(patient_profile, indent=2)}

        Generate SearchAnchors that:
        1. Focus on the patient's specific cancer type, stage, and biomarkers
        2. Include appropriate geographic filters
        3. Consider the patient's age and performance status
        4. Set a relaxation_order for broadening search if too few results
        """

        raw = self._structured_call(prompt, SearchAnchors.model_json_schema())
        return SearchAnchors.model_validate_json(raw)

    async def evaluate_eligibility(
        self,
        patient_profile: dict,
        trial_candidate: dict,
        search_log: object | None = None,
    ) -> EligibilityLedger:
        """Use Gemini to evaluate eligibility for a single trial."""
        # NOTE(review): search_log is accepted but not yet folded into the
        # prompt — presumably reserved for provenance; confirm intent.
        prompt = f"""
        Evaluate this patient's eligibility for the clinical trial below.

        For each inclusion/exclusion criterion:
        1. Assign a criterion_id (inc_1, inc_2, ... or exc_1, exc_2, ...)
        2. Determine if the criterion is met, not_met, or unknown
        3. Provide reasoning and evidence pointers

        Patient Profile:
        {json.dumps(patient_profile, indent=2, default=str)}

        Trial:
        {json.dumps(trial_candidate, indent=2, default=str)}

        Also identify gaps: criteria that are 'unknown' where additional data
        could change the assessment.
        """

        raw = self._structured_call(prompt, EligibilityLedger.model_json_schema())
        return EligibilityLedger.model_validate_json(raw)
trialpath/services/mcp_client.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ClinicalTrials MCP server client wrapper."""
2
+ import httpx
3
+
4
+ from trialpath.models.search_anchors import SearchAnchors
5
+
6
+
7
class MCPError(Exception):
    """Raised when the MCP server answers with a JSON-RPC error object.

    Attributes are kept structured so callers can branch on the numeric
    code without parsing str(self).
    """

    def __init__(self, code: int, message: str):
        # Build the display string first, then retain the raw parts.
        super().__init__(f"MCP Error {code}: {message}")
        self.code = code
        self.message = message
14
+
15
+
16
class ClinicalTrialsMCPClient:
    """Client for the ClinicalTrials MCP Server (JSON-RPC over HTTP)."""

    def __init__(self, mcp_url: str = "http://localhost:3000"):
        self.mcp_url = mcp_url

    @staticmethod
    def _build_filter(anchors: SearchAnchors) -> str | None:
        """Assemble the AREA[...] filter expression, or None when nothing applies."""
        clauses = []

        statuses = anchors.trial_filters.recruitment_status
        if statuses:
            status_filter = " OR ".join(f"AREA[OverallStatus]{s}" for s in statuses)
            clauses.append(f"({status_filter})")

        phases = anchors.trial_filters.phase
        if phases:
            phase_filter = " OR ".join(f"AREA[Phase]{p}" for p in phases)
            clauses.append(f"({phase_filter})")

        if anchors.age is not None:
            # Patient age must fall inside the trial's [min, max] window.
            clauses.append(f"AREA[MinimumAge]RANGE[MIN, {anchors.age}]")
            clauses.append(f"AREA[MaximumAge]RANGE[{anchors.age}, MAX]")

        return " AND ".join(clauses) if clauses else None

    async def search(self, anchors: SearchAnchors) -> list[dict]:
        """Convert SearchAnchors to MCP search_studies call."""
        # Free-text query: condition, optional subtype, then biomarkers.
        terms = [anchors.condition]
        if anchors.subtype:
            terms.append(anchors.subtype)
        terms.extend(anchors.biomarkers or [])

        params: dict = {
            "query": " ".join(terms),
            "pageSize": 50,
            "sort": "LastUpdateDate:desc",
        }
        filter_expr = self._build_filter(anchors)
        if filter_expr:
            params["filter"] = filter_expr
        if anchors.geography:
            params["country"] = anchors.geography.country

        payload = await self._call_tool("clinicaltrials_search_studies", params)
        return payload.get("studies", [])

    async def get_study(self, nct_id: str) -> dict:
        """Fetch full study details by NCT ID; empty dict when not found."""
        payload = await self._call_tool(
            "clinicaltrials_get_study",
            {"nctIds": [nct_id], "summaryOnly": False},
        )
        matches = payload.get("studies", [])
        return matches[0] if matches else {}

    async def find_eligible(
        self,
        age: int,
        sex: str,
        conditions: list[str],
        country: str,
        max_results: int = 20,
    ) -> dict:
        """Use find_eligible_studies for demographic-based matching."""
        arguments = {
            "age": age,
            "sex": sex,
            "conditions": conditions,
            "location": {"country": country},
            "recruitingOnly": True,
            "maxResults": max_results,
        }
        return await self._call_tool("clinicaltrials_find_eligible_studies", arguments)

    async def compare_studies(self, nct_ids: list[str]) -> dict:
        """Compare 2-5 studies side by side."""
        return await self._call_tool(
            "clinicaltrials_compare_studies",
            {"nctIds": nct_ids, "compareFields": "all"},
        )

    async def _call_tool(self, tool_name: str, params: dict) -> dict:
        """POST one JSON-RPC tools/call request and unwrap its result.

        Raises:
            MCPError: when the server returns a JSON-RPC error object.
        """
        envelope = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": params,
            },
            "id": 1,
        }
        async with httpx.AsyncClient(timeout=30.0) as http:
            response = await http.post(
                f"{self.mcp_url}/mcp/v1/tools/call",
                json=envelope,
            )
            response.raise_for_status()
            data = response.json()

        if "error" in data:
            detail = data["error"]
            raise MCPError(
                code=detail.get("code", -1),
                message=detail.get("message", "Unknown MCP error"),
            )

        return data.get("result", {})
trialpath/services/medgemma_extractor.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MedGemma HF endpoint integration for patient profile extraction."""
2
+ import json
3
+ import re
4
+
5
+
6
+ class MedGemmaExtractor:
7
+ """Extract patient profiles from medical documents using MedGemma.
8
+
9
+ For PoC, this is interface-first: the parsing and prompt methods are
10
+ fully implemented, but the actual model call requires a HuggingFace
11
+ Inference Endpoint or local GPU with MedGemma loaded.
12
+ """
13
+
14
+ def __init__(self, endpoint_url: str | None = None, hf_token: str | None = None):
15
+ self.endpoint_url = endpoint_url
16
+ self.hf_token = hf_token
17
+ self.pipe = None # Initialized lazily when model is available
18
+
19
+ def _system_prompt(self) -> str:
20
+ return (
21
+ "You are an expert medical data extractor specializing in oncology. "
22
+ "Extract structured patient information from medical documents. "
23
+ "Always cite the source document and location for each extracted fact. "
24
+ "If information is unclear or missing, explicitly note it as unknown."
25
+ )
26
+
27
+ def _build_extraction_prompt(self, metadata: dict) -> str:
28
+ return f"""
29
+ Extract a structured patient profile from the following medical documents.
30
+
31
+ Known metadata: age={metadata.get('age', 'unknown')}, sex={metadata.get('sex', 'unknown')}
32
+
33
+ Extract the following fields in JSON format:
34
+ - diagnosis (primary_condition, histology, stage, diagnosis_date)
35
+ - performance_status (scale, value, evidence)
36
+ - biomarkers (name, result, date, evidence for each)
37
+ - key_labs (name, value, unit, date, evidence for each)
38
+ - treatments (drug_name, start_date, end_date, line, evidence)
39
+ - comorbidities (name, grade, evidence)
40
+ - imaging_summary (modality, date, finding, interpretation, certainty, evidence)
41
+ - unknowns (field, reason, importance for each missing critical field)
42
+
43
+ For each evidence reference, include: doc_id (filename), page number, span_id.
44
+
45
+ Return ONLY valid JSON matching the PatientProfile schema.
46
+ """
47
+
48
+ def _parse_profile(self, raw_text: str, metadata: dict) -> dict:
49
+ """Parse MedGemma output into PatientProfile structure."""
50
+ try:
51
+ profile = json.loads(raw_text)
52
+ except json.JSONDecodeError:
53
+ json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", raw_text, re.DOTALL)
54
+ if json_match:
55
+ profile = json.loads(json_match.group(1))
56
+ else:
57
+ raise ValueError(
58
+ f"Could not parse MedGemma output as JSON: {raw_text[:200]}"
59
+ )
60
+
61
+ if "demographics" not in profile:
62
+ profile["demographics"] = {}
63
+ profile["demographics"].update(metadata)
64
+
65
+ return profile
66
+
67
+ async def extract(self, document_urls: list[str], metadata: dict) -> dict:
68
+ """Extract PatientProfile from documents via MedGemma.
69
+
70
+ Requires self.pipe to be initialized with a HuggingFace pipeline.
71
+ """
72
+ if self.pipe is None:
73
+ raise RuntimeError(
74
+ "MedGemma pipeline not initialized. "
75
+ "Set up a HF Inference Endpoint or load model locally."
76
+ )
77
+
78
+ content = [
79
+ {"type": "text", "text": self._build_extraction_prompt(metadata)},
80
+ ]
81
+ # In production, images would be loaded from document_urls
82
+ messages = [
83
+ {"role": "system", "content": [{"type": "text", "text": self._system_prompt()}]},
84
+ {"role": "user", "content": content},
85
+ ]
86
+
87
+ output = self.pipe(text=messages, max_new_tokens=2048)
88
+ raw_text = output[0]["generated_text"][-1]["content"]
89
+ return self._parse_profile(raw_text, metadata)
90
+
91
+ async def evaluate_medical_criterion(
92
+ self,
93
+ criterion_text: str,
94
+ patient_profile: object,
95
+ evidence_docs: list,
96
+ ) -> dict:
97
+ """Evaluate a single medical criterion against patient evidence.
98
+
99
+ Stub for PoC -- requires MedGemma pipeline.
100
+ """
101
+ if self.pipe is None:
102
+ raise RuntimeError("MedGemma pipeline not initialized.")
103
+
104
+ prompt = f"""
105
+ Evaluate whether the patient meets this clinical trial criterion:
106
+ CRITERION: {criterion_text}
107
+
108
+ Respond with JSON:
109
+ {{"decision": "met|not_met|unknown", "reasoning": "...", "confidence": 0.0-1.0}}
110
+ """
111
+
112
+ messages = [
113
+ {"role": "system", "content": [{"type": "text", "text": self._system_prompt()}]},
114
+ {"role": "user", "content": [{"type": "text", "text": prompt}]},
115
+ ]
116
+
117
+ output = self.pipe(text=messages, max_new_tokens=1024)
118
+ raw_text = output[0]["generated_text"][-1]["content"]
119
+ return json.loads(raw_text)
trialpath/tests/test_gemini.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TDD tests for Gemini planner service."""
2
+ from unittest.mock import MagicMock, patch
3
+
4
+ import pytest
5
+
6
+ from trialpath.models.eligibility_ledger import EligibilityLedger, OverallAssessment
7
+ from trialpath.models.search_anchors import SearchAnchors
8
+ from trialpath.services.gemini_planner import GeminiPlanner
9
+
10
+
11
class TestGeminiSearchAnchorsGeneration:
    """Test Gemini structured output for SearchAnchors generation.

    Every test patches ``google.genai.Client`` so no real API call is made:
    the planner under test receives a mocked client whose ``generate_content``
    returns canned SearchAnchors JSON.
    """

    @pytest.fixture
    def sample_profile(self):
        # Representative stage-IV NSCLC profile: exercises condition,
        # histology, biomarker, age, and performance-status handling.
        return {
            "patient_id": "P001",
            "demographics": {"age": 52, "sex": "female"},
            "diagnosis": {
                "primary_condition": "Non-Small Cell Lung Cancer",
                "histology": "adenocarcinoma",
                "stage": "IVa",
            },
            "biomarkers": [
                {"name": "EGFR", "result": "Exon 19 deletion"},
            ],
            "performance_status": {"scale": "ECOG", "value": 1},
        }

    @pytest.mark.asyncio
    async def test_search_anchors_has_correct_condition(self, sample_profile):
        """Generated SearchAnchors should reference NSCLC."""
        # Patch the client class before GeminiPlanner() constructs one.
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            # The planner only reads ``response.text``; feed it JSON produced
            # from a real SearchAnchors instance so validation round-trips.
            mock_response.text = SearchAnchors(
                condition="Non-Small Cell Lung Cancer",
                subtype="adenocarcinoma",
                biomarkers=["EGFR exon 19 deletion"],
                stage="IV",
                age=52,
                performance_status_max=1,
            ).model_dump_json()

            MockClient.return_value.models.generate_content = MagicMock(
                return_value=mock_response
            )

            planner = GeminiPlanner()
            anchors = await planner.generate_search_anchors(sample_profile)

            # The condition check is deliberately loose: either spelling is fine.
            assert "lung" in anchors.condition.lower() or "nsclc" in anchors.condition.lower()
            assert anchors.age == 52

    @pytest.mark.asyncio
    async def test_search_anchors_includes_biomarkers(self, sample_profile):
        """SearchAnchors should include patient biomarkers."""
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            mock_response.text = SearchAnchors(
                condition="NSCLC",
                biomarkers=["EGFR exon 19 deletion"],
            ).model_dump_json()

            MockClient.return_value.models.generate_content = MagicMock(
                return_value=mock_response
            )

            planner = GeminiPlanner()
            anchors = await planner.generate_search_anchors(sample_profile)

            # The mocked anchors carry one EGFR biomarker; it must survive
            # JSON validation into the returned model.
            assert len(anchors.biomarkers) > 0
            assert any("EGFR" in b for b in anchors.biomarkers)

    @pytest.mark.asyncio
    async def test_search_anchors_json_schema_passed(self, sample_profile):
        """Verify that Gemini is called with response_json_schema."""
        with patch("google.genai.Client") as MockClient:
            mock_response = MagicMock()
            mock_response.text = SearchAnchors(condition="NSCLC").model_dump_json()

            mock_generate = MagicMock(return_value=mock_response)
            MockClient.return_value.models.generate_content = mock_generate

            planner = GeminiPlanner()
            await planner.generate_search_anchors(sample_profile)

            # Inspect the recorded call: config must request JSON output
            # constrained by a schema (works for kwargs or positional call).
            call_args = mock_generate.call_args
            config = call_args.kwargs.get("config", call_args[1].get("config", {}))
            assert config.get("response_mime_type") == "application/json"
            assert "response_json_schema" in config
91
+
92
+
93
class TestGeminiEligibilityEvaluation:
    """Test Gemini eligibility evaluation output."""

    @pytest.mark.asyncio
    async def test_ledger_has_all_required_fields(self):
        """EligibilityLedger from Gemini should have patient_id, nct_id, assessment."""
        ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT01234567",
            overall_assessment=OverallAssessment.UNCERTAIN,
            criteria=[],
            gaps=[],
        )

        # Core identifying fields must survive model construction unchanged.
        assert ledger.patient_id == "P001"
        assert ledger.nct_id == "NCT01234567"
        assert ledger.overall_assessment in OverallAssessment

    @pytest.mark.asyncio
    async def test_error_handling_invalid_json(self):
        """Should raise error on invalid Gemini JSON response."""
        with patch("google.genai.Client") as MockClient:
            # Simulate Gemini handing back text that is not JSON at all.
            bad_response = MagicMock()
            bad_response.text = "not valid json"
            MockClient.return_value.models.generate_content.return_value = bad_response

            planner = GeminiPlanner()
            with pytest.raises(Exception):
                await planner.evaluate_eligibility({}, {}, None)
trialpath/tests/test_mcp.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TDD tests for ClinicalTrials MCP client."""
2
+ from unittest.mock import AsyncMock, MagicMock, patch
3
+
4
+ import pytest
5
+
6
+ from trialpath.models.search_anchors import GeographyFilter, SearchAnchors, TrialFilters
7
+ from trialpath.services.mcp_client import ClinicalTrialsMCPClient, MCPError
8
+
9
+
10
class TestMCPClient:
    """Test ClinicalTrials MCP client.

    Every test patches ``httpx.AsyncClient`` so no network traffic occurs;
    ``_mock_httpx`` wires up the async-context-manager protocol by hand so
    ``async with httpx.AsyncClient(...)`` yields a mock whose ``post``
    returns a canned JSON-RPC payload.
    """

    @pytest.fixture
    def client(self):
        # The URL is never actually contacted — httpx is always patched.
        return ClinicalTrialsMCPClient(mcp_url="http://localhost:3000")

    @pytest.fixture
    def sample_anchors(self):
        # Fully populated anchors so every branch of the query/filter
        # builder in ClinicalTrialsMCPClient.search is exercised.
        return SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            subtype="adenocarcinoma",
            biomarkers=["EGFR exon 19 deletion"],
            stage="IV",
            age=52,
            geography=GeographyFilter(country="United States"),
            trial_filters=TrialFilters(
                recruitment_status=["Recruiting"],
                phase=["Phase 3"],
            ),
        )

    def _mock_httpx(self, MockHTTP, response_data):
        # Build the response -> client -> context-manager chain by hand:
        # a plain MagicMock does not implement __aenter__/__aexit__.
        mock_response = MagicMock()
        mock_response.json.return_value = response_data
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        mock_ctx = MagicMock()
        mock_ctx.__aenter__ = AsyncMock(return_value=mock_client)
        mock_ctx.__aexit__ = AsyncMock(return_value=None)
        MockHTTP.return_value = mock_ctx
        # Hand back the inner client so tests can inspect post.call_args.
        return mock_client

    @pytest.mark.asyncio
    async def test_search_builds_correct_query(self, client, sample_anchors):
        """Search should combine condition, subtype, and biomarkers into query."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})

            await client.search(sample_anchors)

            # Dig the JSON-RPC body out of the recorded post() call
            # (handles both keyword and positional call shapes).
            call_args = mock_client.post.call_args
            body = call_args.kwargs.get("json", call_args[1].get("json", {}))
            query = body["params"]["arguments"]["query"]

            assert "Non-Small Cell Lung Cancer" in query
            assert "adenocarcinoma" in query

    @pytest.mark.asyncio
    async def test_search_includes_country_filter(self, client, sample_anchors):
        """Search should pass country as a parameter."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})

            await client.search(sample_anchors)

            call_args = mock_client.post.call_args
            body = call_args.kwargs.get("json", call_args[1].get("json", {}))
            args = body["params"]["arguments"]

            # Geography travels as a dedicated argument, not in the query text.
            assert args.get("country") == "United States"

    @pytest.mark.asyncio
    async def test_search_includes_recruitment_status_filter(self, client, sample_anchors):
        """Search should include recruitment status in filter expression."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {"result": {"studies": []}})

            await client.search(sample_anchors)

            call_args = mock_client.post.call_args
            body = call_args.kwargs.get("json", call_args[1].get("json", {}))
            filter_str = body["params"]["arguments"].get("filter", "")

            # Status is encoded as an AREA[OverallStatus] clause.
            assert "OverallStatus" in filter_str
            assert "Recruiting" in filter_str

    @pytest.mark.asyncio
    async def test_get_study_by_nct_id(self, client):
        """Should call get_study tool with correct NCT ID."""
        with patch("httpx.AsyncClient") as MockHTTP:
            self._mock_httpx(MockHTTP, {
                "result": {
                    "studies": [{"nctId": "NCT01234567", "title": "Test Trial"}]
                }
            })

            # get_study unwraps the single-element studies list.
            result = await client.get_study("NCT01234567")
            assert result["nctId"] == "NCT01234567"

    @pytest.mark.asyncio
    async def test_mcp_error_handling(self, client):
        """Should raise MCPError on MCP server error response."""
        with patch("httpx.AsyncClient") as MockHTTP:
            self._mock_httpx(MockHTTP, {
                "error": {"code": -32600, "message": "Invalid request"}
            })

            # A JSON-RPC "error" member must surface as MCPError.
            with pytest.raises(MCPError, match="Invalid request"):
                await client.get_study("NCT00000000")

    @pytest.mark.asyncio
    async def test_find_eligible_passes_demographics(self, client):
        """find_eligible should pass patient demographics correctly."""
        with patch("httpx.AsyncClient") as MockHTTP:
            mock_client = self._mock_httpx(MockHTTP, {
                "result": {"eligibleStudies": [], "totalMatches": 0}
            })

            await client.find_eligible(
                age=52, sex="Female",
                conditions=["NSCLC"],
                country="United States",
            )

            call_args = mock_client.post.call_args
            body = call_args.kwargs.get("json", call_args[1].get("json", {}))
            args = body["params"]["arguments"]

            assert args["age"] == 52
            assert args["sex"] == "Female"
            assert args["conditions"] == ["NSCLC"]
trialpath/tests/test_medgemma.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TDD tests for MedGemma extraction service."""
2
+ import pytest
3
+
4
+ from trialpath.services.medgemma_extractor import MedGemmaExtractor
5
+
6
+
7
class TestMedGemmaExtraction:
    """Test MedGemma extraction pipeline."""

    @staticmethod
    def _bare_extractor():
        # __init__ only stores endpoint config, and the helpers under test
        # read no instance state, so a bare instance is sufficient.
        return MedGemmaExtractor.__new__(MedGemmaExtractor)

    def test_parse_valid_json_output(self):
        """Should parse well-formed JSON from MedGemma."""
        raw_output = """
        {
            "patient_id": "P001",
            "diagnosis": {
                "primary_condition": "Non-Small Cell Lung Cancer",
                "histology": "adenocarcinoma",
                "stage": "IVa"
            },
            "performance_status": {
                "scale": "ECOG",
                "value": 1,
                "evidence": [{"doc_id": "clinic_1", "page": 2, "span_id": "s_17"}]
            },
            "biomarkers": [],
            "unknowns": [
                {"field": "EGFR", "reason": "No clear mention", "importance": "high"}
            ]
        }
        """

        parsed = self._bare_extractor()._parse_profile(
            raw_output, {"age": 52, "sex": "female"}
        )
        assert parsed["diagnosis"]["primary_condition"] == "Non-Small Cell Lung Cancer"
        assert parsed["demographics"]["age"] == 52

    def test_parse_json_in_code_block(self):
        """Should extract JSON from markdown code blocks."""
        raw_output = """Here is the extracted data:
        ```json
        {"patient_id": "P001", "diagnosis": {"primary_condition": "NSCLC", "stage": "IV"}}
        ```
        """

        parsed = self._bare_extractor()._parse_profile(raw_output, {})
        assert parsed["diagnosis"]["primary_condition"] == "NSCLC"

    def test_parse_invalid_output_raises(self):
        """Should raise ValueError on unparseable output."""
        with pytest.raises(ValueError, match="Could not parse"):
            self._bare_extractor()._parse_profile("This is not JSON at all.", {})

    def test_system_prompt_mentions_oncology(self):
        """System prompt should reference oncology expertise."""
        assert "oncology" in self._bare_extractor()._system_prompt().lower()

    def test_extraction_prompt_includes_all_fields(self):
        """Extraction prompt should request all PatientProfile fields."""
        prompt = self._bare_extractor()._build_extraction_prompt(
            {"age": 52, "sex": "female"}
        )

        for field in (
            "diagnosis", "performance_status", "biomarkers",
            "key_labs", "treatments", "comorbidities",
            "imaging_summary", "unknowns",
        ):
            assert field in prompt

    def test_extraction_prompt_includes_metadata(self):
        """Extraction prompt should include provided metadata."""
        prompt = self._bare_extractor()._build_extraction_prompt(
            {"age": 65, "sex": "male"}
        )
        assert "65" in prompt
        assert "male" in prompt