"""Parlant tool definitions for the TrialPath agent.""" import json from parlant.sdk import ToolContext, ToolResult, tool from trialpath.config import ( GEMINI_API_KEY, GEMINI_MODEL, HF_TOKEN, MCP_URL, MEDGEMMA_ENDPOINT_URL, ) # --------------------------------------------------------------------------- # Lazy singletons — one instance per service, reused across tool calls. # --------------------------------------------------------------------------- _extractor = None _planner = None _mcp_client = None def _get_extractor(): global _extractor if _extractor is None: from trialpath.services.medgemma_extractor import MedGemmaExtractor _extractor = MedGemmaExtractor( endpoint_url=MEDGEMMA_ENDPOINT_URL, hf_token=HF_TOKEN, ) return _extractor def _get_planner(): global _planner if _planner is None: from trialpath.services.gemini_planner import GeminiPlanner _planner = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY) return _planner def _get_mcp_client(): global _mcp_client if _mcp_client is None: from trialpath.services.mcp_client import ClinicalTrialsMCPClient _mcp_client = ClinicalTrialsMCPClient(mcp_url=MCP_URL) return _mcp_client # --------------------------------------------------------------------------- # Tools # --------------------------------------------------------------------------- @tool async def extract_patient_profile( context: ToolContext, document_urls: str, metadata: str, ) -> ToolResult: """Extract a structured patient profile from uploaded medical documents. Args: context: Parlant tool context. document_urls: JSON list of document file paths. metadata: JSON object with known patient metadata (age, sex). """ extractor = _get_extractor() urls = json.loads(document_urls) meta = json.loads(metadata) profile = await extractor.extract(urls, meta) return ToolResult( data=profile, metadata={"source": "medgemma", "doc_count": len(urls)}, ) @tool async def generate_search_anchors( context: ToolContext, patient_profile: str, ) -> ToolResult: """Generate search parameters from a patient profile for ClinicalTrials.gov. Args: context: Parlant tool context. patient_profile: JSON string of PatientProfile data. """ planner = _get_planner() profile = json.loads(patient_profile) anchors = await planner.generate_search_anchors(profile) return ToolResult( data=anchors.model_dump(), metadata={"source": "gemini"}, ) @tool async def search_clinical_trials( context: ToolContext, search_anchors: str, ) -> ToolResult: """Search ClinicalTrials.gov for matching trials using search anchors. Args: context: Parlant tool context. search_anchors: JSON string of SearchAnchors data. """ from trialpath.models.search_anchors import SearchAnchors client = _get_mcp_client() anchors = SearchAnchors.model_validate(json.loads(search_anchors)) try: raw_studies = await client.search(anchors) except Exception: # Fallback to direct ClinicalTrials.gov API v2 when MCP unavailable raw_studies = await client.search_direct(anchors) from trialpath.services.mcp_client import ClinicalTrialsMCPClient trials = [ClinicalTrialsMCPClient.normalize_trial(s).model_dump() for s in raw_studies] return ToolResult( data={"trials": trials, "count": len(trials)}, metadata={"source": "clinicaltrials_mcp"}, ) @tool async def refine_search_query( context: ToolContext, search_anchors: str, result_count: str, ) -> ToolResult: """Refine search parameters when too many results returned. Args: context: Parlant tool context. search_anchors: JSON string of current SearchAnchors. result_count: Number of results from last search. """ from trialpath.models.search_anchors import SearchAnchors planner = _get_planner() anchors = SearchAnchors.model_validate(json.loads(search_anchors)) refined = await planner.refine_search(anchors, int(result_count)) return ToolResult( data=refined.model_dump(), metadata={"action": "refine", "prev_count": int(result_count)}, ) @tool async def relax_search_query( context: ToolContext, search_anchors: str, result_count: str, ) -> ToolResult: """Relax search parameters when too few results returned. Args: context: Parlant tool context. search_anchors: JSON string of current SearchAnchors. result_count: Number of results from last search. """ from trialpath.models.search_anchors import SearchAnchors planner = _get_planner() anchors = SearchAnchors.model_validate(json.loads(search_anchors)) relaxed = await planner.relax_search(anchors, int(result_count)) return ToolResult( data=relaxed.model_dump(), metadata={"action": "relax", "prev_count": int(result_count)}, ) @tool async def evaluate_trial_eligibility( context: ToolContext, patient_profile: str, trial_candidate: str, ) -> ToolResult: """Evaluate patient eligibility for a clinical trial using dual-model approach. Medical criteria evaluated by MedGemma, structural by Gemini. Args: context: Parlant tool context. patient_profile: JSON string of PatientProfile data. trial_candidate: JSON string of TrialCandidate data. """ profile = json.loads(patient_profile) trial = json.loads(trial_candidate) planner = _get_planner() extractor = _get_extractor() # Step 1: Slice criteria into atomic items criteria = await planner.slice_criteria(trial) # Step 2: Evaluate each criterion with appropriate model assessments = [] for criterion in criteria: if criterion.get("category") == "medical": result = await extractor.evaluate_medical_criterion(criterion["text"], profile, []) else: result = await planner.evaluate_structural_criterion(criterion["text"], profile) assessments.append({**criterion, **result}) # Step 3: Aggregate into overall assessment ledger = await planner.aggregate_assessments(profile, trial, assessments) return ToolResult( data=ledger.model_dump(), metadata={"source": "dual_model", "criteria_count": len(criteria)}, ) @tool async def analyze_gaps( context: ToolContext, patient_profile: str, eligibility_ledgers: str, ) -> ToolResult: """Analyze eligibility gaps across all evaluated trials. Args: context: Parlant tool context. patient_profile: JSON string of PatientProfile data. eligibility_ledgers: JSON list of EligibilityLedger data. """ planner = _get_planner() profile = json.loads(patient_profile) ledgers = json.loads(eligibility_ledgers) gaps = await planner.analyze_gaps(profile, ledgers) return ToolResult( data={"gaps": gaps, "count": len(gaps)}, metadata={"source": "gemini"}, ) ALL_TOOLS = [ extract_patient_profile, generate_search_anchors, search_clinical_trials, refine_search_query, relax_search_query, evaluate_trial_eligibility, analyze_gaps, ]