| | """Parlant tool definitions for the TrialPath agent.""" |
| |
|
| | import json |
| |
|
| | from parlant.sdk import ToolContext, ToolResult, tool |
| |
|
| | from trialpath.config import ( |
| | GEMINI_API_KEY, |
| | GEMINI_MODEL, |
| | HF_TOKEN, |
| | MCP_URL, |
| | MEDGEMMA_ENDPOINT_URL, |
| | ) |
| |
|
| | |
| | |
| | |
| |
|
# Lazily-created service clients shared by every tool in this module.
# Each slot stays None until its matching _get_* accessor first builds it.
_extractor = _planner = _mcp_client = None
| |
|
| |
|
def _get_extractor():
    """Return the shared MedGemmaExtractor, constructing it on the first call.

    The instance is cached in the module-level ``_extractor`` slot so later
    calls reuse it.
    """
    global _extractor
    if _extractor is not None:
        return _extractor

    # Function-level import: the service module is only loaded the first
    # time an extractor is actually needed.
    from trialpath.services.medgemma_extractor import MedGemmaExtractor

    _extractor = MedGemmaExtractor(
        endpoint_url=MEDGEMMA_ENDPOINT_URL,
        hf_token=HF_TOKEN,
    )
    return _extractor
| |
|
| |
|
def _get_planner():
    """Return the shared GeminiPlanner, constructing it on the first call.

    The instance is cached in the module-level ``_planner`` slot so later
    calls reuse it.
    """
    global _planner
    if _planner is not None:
        return _planner

    # Function-level import: only load the planner module when first needed.
    from trialpath.services.gemini_planner import GeminiPlanner

    _planner = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
    return _planner
| |
|
| |
|
def _get_mcp_client():
    """Return the shared ClinicalTrialsMCPClient, constructing it on the first call.

    The instance is cached in the module-level ``_mcp_client`` slot so later
    calls reuse it.
    """
    global _mcp_client
    if _mcp_client is not None:
        return _mcp_client

    # Function-level import: only load the MCP client module when first needed.
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    _mcp_client = ClinicalTrialsMCPClient(mcp_url=MCP_URL)
    return _mcp_client
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
@tool
async def extract_patient_profile(
    context: ToolContext,
    document_urls: str,
    metadata: str,
) -> ToolResult:
    """Extract a structured patient profile from uploaded medical documents.

    Args:
        context: Parlant tool context.
        document_urls: JSON list of document file paths.
        metadata: JSON object with known patient metadata (age, sex).
    """
    # The docstring above is kept verbatim because it may be surfaced to the
    # agent as the tool description.
    # Both string arguments are JSON-decoded and handed to the shared MedGemma
    # extractor. ``context`` is not used. ``json.loads`` raises on malformed
    # input; no error handling is done here.
    extractor = _get_extractor()
    paths, known_meta = json.loads(document_urls), json.loads(metadata)
    extracted = await extractor.extract(paths, known_meta)

    provenance = {"source": "medgemma", "doc_count": len(paths)}
    return ToolResult(data=extracted, metadata=provenance)
| |
|
| |
|
@tool
async def generate_search_anchors(
    context: ToolContext,
    patient_profile: str,
) -> ToolResult:
    """Generate search parameters from a patient profile for ClinicalTrials.gov.

    Args:
        context: Parlant tool context.
        patient_profile: JSON string of PatientProfile data.
    """
    # Decode the profile, let the Gemini planner derive search anchors, and
    # return them as a plain dict via ``model_dump()``. ``context`` is unused.
    planner = _get_planner()
    anchors = await planner.generate_search_anchors(json.loads(patient_profile))
    return ToolResult(data=anchors.model_dump(), metadata={"source": "gemini"})
| |
|
| |
|
@tool
async def search_clinical_trials(
    context: ToolContext,
    search_anchors: str,
) -> ToolResult:
    """Search ClinicalTrials.gov for matching trials using search anchors.

    Args:
        context: Parlant tool context.
        search_anchors: JSON string of SearchAnchors data.
    """
    # The docstring above is kept verbatim because it may be surfaced to the
    # agent as the tool description.
    #
    # Flow:
    #   1. Validate the JSON anchors into a SearchAnchors model.
    #   2. Try ``client.search`` (the MCP path). If it raises, log the failure
    #      and retry once with ``client.search_direct``. A failure in the
    #      fallback propagates; the MCP error is attached as its __context__.
    #   3. Normalize each raw study and return them with a count.
    #
    # metadata["source"] is kept as "clinicaltrials_mcp" for backward
    # compatibility. metadata["transport"] ("mcp" or "direct") records which
    # call actually produced the data.
    import logging

    from trialpath.models.search_anchors import SearchAnchors
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    client = _get_mcp_client()
    anchors = SearchAnchors.model_validate(json.loads(search_anchors))

    transport = "mcp"
    try:
        raw_studies = await client.search(anchors)
    except Exception:
        # Deliberate best-effort fallback, but leave a trace so MCP failures
        # are not silently swallowed.
        logging.getLogger(__name__).warning(
            "MCP trial search failed; retrying with search_direct",
            exc_info=True,
        )
        transport = "direct"
        raw_studies = await client.search_direct(anchors)

    trials = [
        ClinicalTrialsMCPClient.normalize_trial(study).model_dump()
        for study in raw_studies
    ]

    return ToolResult(
        data={"trials": trials, "count": len(trials)},
        metadata={"source": "clinicaltrials_mcp", "transport": transport},
    )
| |
|
| |
|
@tool
async def refine_search_query(
    context: ToolContext,
    search_anchors: str,
    result_count: str,
) -> ToolResult:
    """Refine search parameters when too many results returned.

    Args:
        context: Parlant tool context.
        search_anchors: JSON string of current SearchAnchors.
        result_count: Number of results from last search.
    """
    # Validate the current anchors, ask the Gemini planner to narrow them
    # given the previous hit count, and echo that count in the metadata.
    # ``int(result_count)`` raises ValueError for non-integer strings.
    from trialpath.models.search_anchors import SearchAnchors

    planner = _get_planner()
    current = SearchAnchors.model_validate(json.loads(search_anchors))
    prev_count = int(result_count)
    narrowed = await planner.refine_search(current, prev_count)

    return ToolResult(
        data=narrowed.model_dump(),
        metadata={"action": "refine", "prev_count": prev_count},
    )
| |
|
| |
|
@tool
async def relax_search_query(
    context: ToolContext,
    search_anchors: str,
    result_count: str,
) -> ToolResult:
    """Relax search parameters when too few results returned.

    Args:
        context: Parlant tool context.
        search_anchors: JSON string of current SearchAnchors.
        result_count: Number of results from last search.
    """
    # Validate the current anchors, ask the Gemini planner to broaden them
    # given the previous hit count, and echo that count in the metadata.
    # ``int(result_count)`` raises ValueError for non-integer strings.
    from trialpath.models.search_anchors import SearchAnchors

    planner = _get_planner()
    current = SearchAnchors.model_validate(json.loads(search_anchors))
    prev_count = int(result_count)
    broadened = await planner.relax_search(current, prev_count)

    return ToolResult(
        data=broadened.model_dump(),
        metadata={"action": "relax", "prev_count": prev_count},
    )
| |
|
| |
|
@tool
async def evaluate_trial_eligibility(
    context: ToolContext,
    patient_profile: str,
    trial_candidate: str,
) -> ToolResult:
    """Evaluate patient eligibility for a clinical trial using dual-model approach.

    Medical criteria evaluated by MedGemma, structural by Gemini.

    Args:
        context: Parlant tool context.
        patient_profile: JSON string of PatientProfile data.
        trial_candidate: JSON string of TrialCandidate data.
    """
    profile = json.loads(patient_profile)
    trial = json.loads(trial_candidate)

    planner = _get_planner()
    extractor = _get_extractor()

    # Step 1: the planner splits the trial into individual criteria.
    # Each is presumably a dict with at least "text" and optionally
    # "category" (see the lookups below).
    criteria = await planner.slice_criteria(trial)

    # Step 2: route each criterion to a model, sequentially and in order.
    # Criteria whose category is "medical" go to MedGemma; all others go to
    # the Gemini planner. Each model's result dict is merged over the
    # criterion dict, so result keys win on collision.
    assessments = []
    for criterion in criteria:
        is_medical = criterion.get("category") == "medical"
        text = criterion["text"]
        if is_medical:
            # NOTE(review): the third argument is always an empty list here;
            # presumably supporting evidence/documents. Confirm against
            # MedGemmaExtractor.evaluate_medical_criterion.
            verdict = await extractor.evaluate_medical_criterion(text, profile, [])
        else:
            verdict = await planner.evaluate_structural_criterion(text, profile)
        assessments.append({**criterion, **verdict})

    # Step 3: the planner folds the per-criterion results into one ledger.
    ledger = await planner.aggregate_assessments(profile, trial, assessments)

    return ToolResult(
        data=ledger.model_dump(),
        metadata={"source": "dual_model", "criteria_count": len(criteria)},
    )
| |
|
| |
|
@tool
async def analyze_gaps(
    context: ToolContext,
    patient_profile: str,
    eligibility_ledgers: str,
) -> ToolResult:
    """Analyze eligibility gaps across all evaluated trials.

    Args:
        context: Parlant tool context.
        patient_profile: JSON string of PatientProfile data.
        eligibility_ledgers: JSON list of EligibilityLedger data.
    """
    # Decode both arguments and let the Gemini planner identify gaps.
    # ``gaps`` is returned as-is together with ``len(gaps)``.
    planner = _get_planner()
    gaps = await planner.analyze_gaps(
        json.loads(patient_profile),
        json.loads(eligibility_ledgers),
    )
    summary = {"gaps": gaps, "count": len(gaps)}
    return ToolResult(data=summary, metadata={"source": "gemini"})
| |
|
| |
|
# Every @tool defined in this module, in declaration order.
# NOTE(review): presumably consumed by the agent setup code to register the
# tools; keep the order stable unless callers are checked.
ALL_TOOLS = [
    # Patient profile extraction
    extract_patient_profile,
    # Trial search and query tuning
    generate_search_anchors,
    search_clinical_trials,
    refine_search_query,
    relax_search_query,
    # Eligibility evaluation and gap analysis
    evaluate_trial_eligibility,
    analyze_gaps,
]
| |
|