yakilee Claude Opus 4.6 committed on
Commit
601f310
·
1 Parent(s): 4b8585c

feat: implement 7 Parlant tools

Browse files

Add @tool decorated functions: extract_patient_profile,
generate_search_anchors, search_clinical_trials, refine_search_query,
relax_search_query, evaluate_trial_eligibility (dual-model), and
analyze_gaps. Each returns ToolResult with data and metadata.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

trialpath/agent/tools.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Parlant tool definitions for the TrialPath agent."""
2
+ import json
3
+
4
+ from parlant.sdk import ToolContext, ToolResult, tool
5
+
6
+ from trialpath.config import (
7
+ GEMINI_API_KEY,
8
+ GEMINI_MODEL,
9
+ HF_TOKEN,
10
+ MCP_URL,
11
+ MEDGEMMA_ENDPOINT_URL,
12
+ )
13
+
14
+
15
@tool
async def extract_patient_profile(
    context: ToolContext,
    document_urls: str,
    metadata: str,
) -> ToolResult:
    """Extract a structured patient profile from uploaded medical documents.

    Args:
        context: Parlant tool context.
        document_urls: JSON list of document file paths.
        metadata: JSON object with known patient metadata (age, sex).
    """
    # Function-scope import: the extractor class is resolved on each call.
    from trialpath.services.medgemma_extractor import MedGemmaExtractor

    medgemma = MedGemmaExtractor(
        endpoint_url=MEDGEMMA_ENDPOINT_URL,
        hf_token=HF_TOKEN,
    )

    # Both arguments arrive as JSON-encoded strings and are decoded here.
    document_paths = json.loads(document_urls)
    known_metadata = json.loads(metadata)

    extracted_profile = await medgemma.extract(document_paths, known_metadata)

    # doc_count reports how many entries the decoded document list held.
    result_metadata = {"source": "medgemma", "doc_count": len(document_paths)}
    return ToolResult(data=extracted_profile, metadata=result_metadata)
42
+
43
+
44
@tool
async def generate_search_anchors(
    context: ToolContext,
    patient_profile: str,
) -> ToolResult:
    """Generate search parameters from a patient profile for ClinicalTrials.gov.

    Args:
        context: Parlant tool context.
        patient_profile: JSON string of PatientProfile data.
    """
    # Function-scope import: the planner class is resolved on each call.
    from trialpath.services.gemini_planner import GeminiPlanner

    gemini = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
    decoded_profile = json.loads(patient_profile)
    search_anchors = await gemini.generate_search_anchors(decoded_profile)

    # The planner returns a model object; the tool result carries its dumped form.
    payload = search_anchors.model_dump()
    return ToolResult(data=payload, metadata={"source": "gemini"})
65
+
66
+
67
@tool
async def search_clinical_trials(
    context: ToolContext,
    search_anchors: str,
) -> ToolResult:
    """Search ClinicalTrials.gov for matching trials using search anchors.

    Args:
        context: Parlant tool context.
        search_anchors: JSON string of SearchAnchors data.
    """
    # Function-scope imports: both names are resolved on each call.
    from trialpath.models.search_anchors import SearchAnchors
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    mcp = ClinicalTrialsMCPClient(mcp_url=MCP_URL)
    parsed_anchors = SearchAnchors.model_validate(json.loads(search_anchors))
    studies = await mcp.search(parsed_anchors)

    # Normalize every raw study and keep only its dumped representation.
    normalized = []
    for study in studies:
        candidate = ClinicalTrialsMCPClient.normalize_trial(study)
        normalized.append(candidate.model_dump())

    payload = {"trials": normalized, "count": len(normalized)}
    return ToolResult(data=payload, metadata={"source": "clinicaltrials_mcp"})
94
+
95
+
96
@tool
async def refine_search_query(
    context: ToolContext,
    search_anchors: str,
    result_count: str,
) -> ToolResult:
    """Refine search parameters when too many results returned.

    Args:
        context: Parlant tool context.
        search_anchors: JSON string of current SearchAnchors.
        result_count: Number of results from last search.
    """
    # Function-scope imports: both names are resolved on each call.
    from trialpath.models.search_anchors import SearchAnchors
    from trialpath.services.gemini_planner import GeminiPlanner

    gemini = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
    current_anchors = SearchAnchors.model_validate(json.loads(search_anchors))

    # result_count arrives as a string; the planner and the metadata both
    # receive it as an integer.
    previous_count = int(result_count)
    narrowed = await gemini.refine_search(current_anchors, previous_count)

    result_metadata = {"action": "refine", "prev_count": previous_count}
    return ToolResult(data=narrowed.model_dump(), metadata=result_metadata)
120
+
121
+
122
@tool
async def relax_search_query(
    context: ToolContext,
    search_anchors: str,
    result_count: str,
) -> ToolResult:
    """Relax search parameters when too few results returned.

    Args:
        context: Parlant tool context.
        search_anchors: JSON string of current SearchAnchors.
        result_count: Number of results from last search.
    """
    # Function-scope imports: both names are resolved on each call.
    from trialpath.models.search_anchors import SearchAnchors
    from trialpath.services.gemini_planner import GeminiPlanner

    gemini = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
    current_anchors = SearchAnchors.model_validate(json.loads(search_anchors))

    # result_count arrives as a string; the planner and the metadata both
    # receive it as an integer.
    previous_count = int(result_count)
    widened = await gemini.relax_search(current_anchors, previous_count)

    result_metadata = {"action": "relax", "prev_count": previous_count}
    return ToolResult(data=widened.model_dump(), metadata=result_metadata)
146
+
147
+
148
@tool
async def evaluate_trial_eligibility(
    context: ToolContext,
    patient_profile: str,
    trial_candidate: str,
) -> ToolResult:
    """Evaluate patient eligibility for a clinical trial using dual-model approach.

    Medical criteria evaluated by MedGemma, structural by Gemini.

    Args:
        context: Parlant tool context.
        patient_profile: JSON string of PatientProfile data.
        trial_candidate: JSON string of TrialCandidate data.
    """
    # Function-scope imports: both service classes are resolved on each call.
    from trialpath.services.gemini_planner import GeminiPlanner
    from trialpath.services.medgemma_extractor import MedGemmaExtractor

    profile_data = json.loads(patient_profile)
    trial_data = json.loads(trial_candidate)

    gemini = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
    medgemma = MedGemmaExtractor(
        endpoint_url=MEDGEMMA_ENDPOINT_URL,
        hf_token=HF_TOKEN,
    )

    # Step 1: have the planner split the trial into individual criteria.
    atomic_criteria = await gemini.slice_criteria(trial_data)

    # Step 2: evaluate the criteria one at a time, in order. Items whose
    # "category" is "medical" go to MedGemma; everything else goes to Gemini.
    evaluated = []
    for item in atomic_criteria:
        is_medical = item.get("category") == "medical"
        if not is_medical:
            verdict = await gemini.evaluate_structural_criterion(
                item["text"], profile_data
            )
        else:
            # The third argument is an empty list, exactly as before.
            verdict = await medgemma.evaluate_medical_criterion(
                item["text"], profile_data, []
            )
        # Keys from the verdict override same-named keys of the criterion.
        evaluated.append({**item, **verdict})

    # Step 3: let the planner fold the per-criterion results into one ledger.
    ledger = await gemini.aggregate_assessments(profile_data, trial_data, evaluated)

    result_metadata = {
        "source": "dual_model",
        "criteria_count": len(atomic_criteria),
    }
    return ToolResult(data=ledger.model_dump(), metadata=result_metadata)
198
+
199
+
200
@tool
async def analyze_gaps(
    context: ToolContext,
    patient_profile: str,
    eligibility_ledgers: str,
) -> ToolResult:
    """Analyze eligibility gaps across all evaluated trials.

    Args:
        context: Parlant tool context.
        patient_profile: JSON string of PatientProfile data.
        eligibility_ledgers: JSON list of EligibilityLedger data.
    """
    # Function-scope import: the planner class is resolved on each call.
    from trialpath.services.gemini_planner import GeminiPlanner

    gemini = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
    decoded_profile = json.loads(patient_profile)
    decoded_ledgers = json.loads(eligibility_ledgers)

    found_gaps = await gemini.analyze_gaps(decoded_profile, decoded_ledgers)

    # The planner's result is returned as-is, alongside its length.
    payload = {"gaps": found_gaps, "count": len(found_gaps)}
    return ToolResult(data=payload, metadata={"source": "gemini"})
224
+
225
+
226
# The seven tool objects defined in this module, listed in definition order.
ALL_TOOLS = [
    extract_patient_profile,
    generate_search_anchors,
    search_clinical_trials,
    refine_search_query,
    relax_search_query,
    evaluate_trial_eligibility,
    analyze_gaps,
]
trialpath/tests/test_tools.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TDD tests for Parlant tool functions."""
2
+ import json
3
+ from unittest.mock import AsyncMock, MagicMock, patch
4
+
5
+ import pytest
6
+
7
+ from trialpath.agent.tools import (
8
+ ALL_TOOLS,
9
+ analyze_gaps,
10
+ evaluate_trial_eligibility,
11
+ extract_patient_profile,
12
+ generate_search_anchors,
13
+ refine_search_query,
14
+ relax_search_query,
15
+ search_clinical_trials,
16
+ )
17
+
18
+
19
@pytest.fixture
def mock_context():
    """Provide a MagicMock to pass as the tools' ``context`` argument."""
    context_double = MagicMock()
    return context_double
22
+
23
+
24
class TestExtractPatientProfile:
    """Exercise extract_patient_profile against a mocked MedGemmaExtractor."""

    @pytest.mark.asyncio
    async def test_calls_medgemma_extractor(self, mock_context):
        """The tool calls MedGemmaExtractor.extract once and returns its profile."""
        expected_profile = {
            "patient_id": "P001",
            "diagnosis": {"primary_condition": "NSCLC"},
        }
        # Patched where it is defined: the tool imports the class at call time.
        extractor_path = "trialpath.services.medgemma_extractor.MedGemmaExtractor"

        with patch(extractor_path) as extractor_cls:
            extract_mock = AsyncMock(return_value=expected_profile)
            extractor_cls.return_value.extract = extract_mock

            outcome = await extract_patient_profile.function(
                mock_context,
                document_urls=json.dumps(["doc1.pdf"]),
                metadata=json.dumps({"age": 52}),
            )

        extract_mock.assert_called_once()
        assert outcome.data["patient_id"] == "P001"

    @pytest.mark.asyncio
    async def test_returns_tool_result_with_metadata(self, mock_context):
        """The result metadata names the source and counts the documents."""
        extractor_path = "trialpath.services.medgemma_extractor.MedGemmaExtractor"

        with patch(extractor_path) as extractor_cls:
            extractor_cls.return_value.extract = AsyncMock(return_value={})

            outcome = await extract_patient_profile.function(
                mock_context,
                document_urls=json.dumps(["a.pdf", "b.pdf"]),
                metadata=json.dumps({}),
            )

        result_metadata = outcome.metadata
        assert result_metadata["source"] == "medgemma"
        assert result_metadata["doc_count"] == 2
62
+
63
+
64
class TestGenerateSearchAnchors:
    """Exercise generate_search_anchors against a mocked GeminiPlanner."""

    @pytest.mark.asyncio
    async def test_calls_gemini_planner(self, mock_context):
        """The tool returns the dumped anchors produced by the planner."""
        from trialpath.models.search_anchors import SearchAnchors

        planner_anchors = SearchAnchors(condition="NSCLC")
        # Patched where it is defined: the tool imports the class at call time.
        planner_path = "trialpath.services.gemini_planner.GeminiPlanner"

        with patch(planner_path) as planner_cls:
            generate_mock = AsyncMock(return_value=planner_anchors)
            planner_cls.return_value.generate_search_anchors = generate_mock

            outcome = await generate_search_anchors.function(
                mock_context,
                patient_profile=json.dumps({"patient_id": "P001"}),
            )

        assert outcome.data["condition"] == "NSCLC"
87
+
88
+
89
class TestSearchClinicalTrials:
    """Exercise search_clinical_trials against a mocked MCP client."""

    @pytest.mark.asyncio
    async def test_calls_mcp_client_and_normalizes(self, mock_context):
        """One raw study yields one trial, tagged with the MCP source."""
        study_payload = {"nctId": "NCT001", "title": "Test Trial"}
        # Patched where it is defined: the tool imports the class at call time.
        client_path = "trialpath.services.mcp_client.ClinicalTrialsMCPClient"

        with patch(client_path) as client_cls:
            client_cls.return_value.search = AsyncMock(
                return_value=[study_payload]
            )

            # normalize_trial is called on the class itself, so it is stubbed
            # on the patched class rather than on its return_value.
            normalized_trial = MagicMock()
            normalized_trial.model_dump.return_value = {
                "nct_id": "NCT001",
                "title": "Test Trial",
            }
            client_cls.normalize_trial = MagicMock(return_value=normalized_trial)

            outcome = await search_clinical_trials.function(
                mock_context,
                search_anchors=json.dumps({"condition": "NSCLC"}),
            )

        assert outcome.data["count"] == 1
        assert outcome.metadata["source"] == "clinicaltrials_mcp"
114
+
115
+
116
class TestRefineSearchQuery:
    """Exercise refine_search_query against a mocked GeminiPlanner."""

    @pytest.mark.asyncio
    async def test_calls_gemini_refine(self, mock_context):
        """The metadata records the refine action and the integer count."""
        from trialpath.models.search_anchors import SearchAnchors

        narrowed_anchors = SearchAnchors(condition="NSCLC", biomarkers=["EGFR"])
        # Patched where it is defined: the tool imports the class at call time.
        planner_path = "trialpath.services.gemini_planner.GeminiPlanner"

        with patch(planner_path) as planner_cls:
            refine_mock = AsyncMock(return_value=narrowed_anchors)
            planner_cls.return_value.refine_search = refine_mock

            outcome = await refine_search_query.function(
                mock_context,
                search_anchors=json.dumps({"condition": "NSCLC"}),
                result_count="100",
            )

        result_metadata = outcome.metadata
        assert result_metadata["action"] == "refine"
        # The count is passed in as the string "100" and comes back as an int.
        assert result_metadata["prev_count"] == 100
141
+
142
+
143
class TestRelaxSearchQuery:
    """Test relax_search_query tool."""

    @pytest.mark.asyncio
    async def test_calls_gemini_relax(self, mock_context):
        """Should call GeminiPlanner.relax_search and report the relax metadata."""
        from trialpath.models.search_anchors import SearchAnchors

        mock_relaxed = SearchAnchors(condition="NSCLC")

        # Patched where it is defined: the tool imports the class at call time.
        with patch(
            "trialpath.services.gemini_planner.GeminiPlanner"
        ) as MockPlanner:
            MockPlanner.return_value.relax_search = AsyncMock(
                return_value=mock_relaxed
            )

            result = await relax_search_query.function(
                mock_context,
                search_anchors=json.dumps({"condition": "NSCLC"}),
                result_count="0",
            )

        # Verify the call that the test name and docstring promise.
        MockPlanner.return_value.relax_search.assert_awaited_once()
        assert result.metadata["action"] == "relax"
        # The string "0" must come back as the integer 0, matching the
        # prev_count check in the refine test.
        assert result.metadata["prev_count"] == 0
167
+
168
+
169
class TestEvaluateTrialEligibility:
    """Exercise evaluate_trial_eligibility with both model services mocked."""

    @pytest.mark.asyncio
    async def test_dual_model_evaluation(self, mock_context):
        """A medical criterion goes to MedGemma and a structural one to Gemini."""
        from trialpath.models.eligibility_ledger import (
            EligibilityLedger,
            OverallAssessment,
        )

        expected_ledger = EligibilityLedger(
            patient_id="P001",
            nct_id="NCT001",
            overall_assessment=OverallAssessment.LIKELY_ELIGIBLE,
            criteria=[],
            gaps=[],
        )

        # One criterion per category so each evaluator is hit exactly once.
        medical_criterion = {
            "criterion_id": "inc_1",
            "type": "inclusion",
            "text": "EGFR mutation",
            "category": "medical",
        }
        structural_criterion = {
            "criterion_id": "inc_2",
            "type": "inclusion",
            "text": "Age >= 18",
            "category": "structural",
        }

        medical_verdict = {"decision": "met", "reasoning": "OK", "confidence": 0.9}
        structural_verdict = {
            "decision": "met",
            "reasoning": "OK",
            "confidence": 0.99,
        }

        # Patched where they are defined: the tool imports both at call time.
        with (
            patch("trialpath.services.gemini_planner.GeminiPlanner") as planner_cls,
            patch(
                "trialpath.services.medgemma_extractor.MedGemmaExtractor"
            ) as extractor_cls,
        ):
            planner = planner_cls.return_value
            extractor = extractor_cls.return_value

            planner.slice_criteria = AsyncMock(
                return_value=[medical_criterion, structural_criterion]
            )
            extractor.evaluate_medical_criterion = AsyncMock(
                return_value=medical_verdict
            )
            planner.evaluate_structural_criterion = AsyncMock(
                return_value=structural_verdict
            )
            planner.aggregate_assessments = AsyncMock(return_value=expected_ledger)

            outcome = await evaluate_trial_eligibility.function(
                mock_context,
                patient_profile=json.dumps({"patient_id": "P001"}),
                trial_candidate=json.dumps({"nct_id": "NCT001"}),
            )

        assert outcome.data["overall_assessment"] == "likely_eligible"
        assert outcome.metadata["criteria_count"] == 2
        extractor.evaluate_medical_criterion.assert_called_once()
        planner.evaluate_structural_criterion.assert_called_once()
232
+
233
+
234
class TestAnalyzeGaps:
    """Exercise analyze_gaps against a mocked GeminiPlanner."""

    @pytest.mark.asyncio
    async def test_calls_gemini_gap_analysis(self, mock_context):
        """The planner's gap list is returned together with its length."""
        mri_gap = {
            "description": "Brain MRI needed",
            "recommended_action": "Upload MRI",
            "clinical_importance": "high",
            "affected_trial_count": 2,
        }
        planner_gaps = [mri_gap]
        # Patched where it is defined: the tool imports the class at call time.
        planner_path = "trialpath.services.gemini_planner.GeminiPlanner"

        with patch(planner_path) as planner_cls:
            planner_cls.return_value.analyze_gaps = AsyncMock(
                return_value=planner_gaps
            )

            outcome = await analyze_gaps.function(
                mock_context,
                patient_profile=json.dumps({}),
                eligibility_ledgers=json.dumps([]),
            )

        payload = outcome.data
        assert payload["count"] == 1
        assert payload["gaps"][0]["clinical_importance"] == "high"
262
+
263
+
264
class TestAllToolsExported:
    """Check the size and element type of the ALL_TOOLS list."""

    def test_all_tools_has_7_entries(self):
        """ALL_TOOLS holds exactly seven items."""
        expected_tool_count = 7
        assert len(ALL_TOOLS) == expected_tool_count

    def test_all_tools_are_tool_entries(self):
        """Every item in ALL_TOOLS is a parlant ToolEntry."""
        from parlant.sdk import ToolEntry

        for entry in ALL_TOOLS:
            assert isinstance(entry, ToolEntry), f"{entry} is not a ToolEntry"