yakilee's picture
feat: wire real API pipeline with demo mode and Gemini truncation fix
94adbfa
"""Parlant tool definitions for the TrialPath agent."""
import json
from parlant.sdk import ToolContext, ToolResult, tool
from trialpath.config import (
GEMINI_API_KEY,
GEMINI_MODEL,
HF_TOKEN,
MCP_URL,
MEDGEMMA_ENDPOINT_URL,
)
# ---------------------------------------------------------------------------
# Lazy singletons — one instance per service, reused across tool calls.
# ---------------------------------------------------------------------------
# Cached service clients; each stays None until its _get_* accessor below
# builds it on first use.
_extractor = _planner = _mcp_client = None
def _get_extractor():
    """Return the shared MedGemmaExtractor, constructing it on the first call.

    The instance is cached in the module-level ``_extractor`` and reused by
    every later call.
    """
    global _extractor
    if _extractor is not None:
        return _extractor
    # Imported here so the service module is only loaded once a tool needs it.
    from trialpath.services.medgemma_extractor import MedGemmaExtractor

    _extractor = MedGemmaExtractor(endpoint_url=MEDGEMMA_ENDPOINT_URL, hf_token=HF_TOKEN)
    return _extractor
def _get_planner():
    """Return the shared GeminiPlanner, constructing it on the first call.

    The instance is cached in the module-level ``_planner`` and reused by
    every later call.
    """
    global _planner
    if _planner is not None:
        return _planner
    # Imported here so the service module is only loaded once a tool needs it.
    from trialpath.services.gemini_planner import GeminiPlanner

    _planner = GeminiPlanner(model=GEMINI_MODEL, api_key=GEMINI_API_KEY)
    return _planner
def _get_mcp_client():
    """Return the shared ClinicalTrialsMCPClient, constructing it on the first call.

    The instance is cached in the module-level ``_mcp_client`` and reused by
    every later call.
    """
    global _mcp_client
    if _mcp_client is not None:
        return _mcp_client
    # Imported here so the service module is only loaded once a tool needs it.
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    _mcp_client = ClinicalTrialsMCPClient(mcp_url=MCP_URL)
    return _mcp_client
# ---------------------------------------------------------------------------
# Tools
# ---------------------------------------------------------------------------
@tool
async def extract_patient_profile(
    context: ToolContext,
    document_urls: str,
    metadata: str,
) -> ToolResult:
    """Extract a structured patient profile from uploaded medical documents.

    Args:
        context: Parlant tool context (not used by this tool).
        document_urls: JSON list of document file paths. A single JSON string
            is also accepted and treated as a one-element list.
        metadata: JSON object with known patient metadata (age, sex). JSON
            ``null`` is treated as an empty object.

    Returns:
        ToolResult whose data is the extractor's profile and whose metadata
        records the source and the number of documents processed.

    Raises:
        ValueError: If either argument is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError``), if ``document_urls`` is neither a list nor a
            string, or if ``metadata`` is not a JSON object.
    """
    extractor = _get_extractor()

    urls = json.loads(document_urls)
    # A bare JSON string would otherwise reach the extractor as a str, and
    # len(urls) would count characters instead of documents.
    if isinstance(urls, str):
        urls = [urls]
    elif not isinstance(urls, list):
        raise ValueError(
            f"document_urls must be a JSON list of paths, got {type(urls).__name__}"
        )

    meta = json.loads(metadata)
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise ValueError(
            f"metadata must be a JSON object, got {type(meta).__name__}"
        )

    profile = await extractor.extract(urls, meta)
    return ToolResult(
        data=profile,
        metadata={"source": "medgemma", "doc_count": len(urls)},
    )
@tool
async def generate_search_anchors(
    context: ToolContext,
    patient_profile: str,
) -> ToolResult:
    """Generate search parameters from a patient profile for ClinicalTrials.gov.

    Args:
        context: Parlant tool context (not used by this tool).
        patient_profile: JSON string of PatientProfile data.

    Returns:
        ToolResult carrying the planner's search anchors as a plain dict,
        tagged with ``source="gemini"``.
    """
    gemini = _get_planner()
    search_params = await gemini.generate_search_anchors(json.loads(patient_profile))
    payload = search_params.model_dump()
    return ToolResult(data=payload, metadata={"source": "gemini"})
@tool
async def search_clinical_trials(
    context: ToolContext,
    search_anchors: str,
) -> ToolResult:
    """Search ClinicalTrials.gov for matching trials using search anchors.

    Queries the MCP server first. If that call raises, the failure is logged
    and the direct ClinicalTrials.gov API v2 path (``search_direct``) is used
    instead.

    Args:
        context: Parlant tool context (not used by this tool).
        search_anchors: JSON string of SearchAnchors data.

    Returns:
        ToolResult whose data holds the normalized trials and their count.
        ``metadata["source"]`` names the backend that actually produced the
        results: ``"clinicaltrials_mcp"`` or ``"clinicaltrials_direct"``.
    """
    import logging

    from trialpath.models.search_anchors import SearchAnchors
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    client = _get_mcp_client()
    anchors = SearchAnchors.model_validate(json.loads(search_anchors))

    source = "clinicaltrials_mcp"
    try:
        raw_studies = await client.search(anchors)
    except Exception:
        # Fallback to direct ClinicalTrials.gov API v2 when MCP unavailable.
        # Record why, rather than discarding the MCP error silently.
        logging.getLogger(__name__).warning(
            "MCP trial search failed; falling back to direct ClinicalTrials.gov API",
            exc_info=True,
        )
        raw_studies = await client.search_direct(anchors)
        # Label provenance accurately: these results did not come from MCP.
        source = "clinicaltrials_direct"

    trials = [ClinicalTrialsMCPClient.normalize_trial(s).model_dump() for s in raw_studies]
    return ToolResult(
        data={"trials": trials, "count": len(trials)},
        metadata={"source": source},
    )
@tool
async def refine_search_query(
    context: ToolContext,
    search_anchors: str,
    result_count: str,
) -> ToolResult:
    """Refine search parameters when too many results returned.

    Args:
        context: Parlant tool context (not used by this tool).
        search_anchors: JSON string of current SearchAnchors.
        result_count: Number of results from last search, as an integer string.

    Returns:
        ToolResult with the narrowed anchors as a dict. Its metadata records the
        action and the previous result count.
    """
    from trialpath.models.search_anchors import SearchAnchors

    planner = _get_planner()
    current = SearchAnchors.model_validate(json.loads(search_anchors))
    prev_count = int(result_count)
    narrowed = await planner.refine_search(current, prev_count)
    return ToolResult(
        data=narrowed.model_dump(),
        metadata={"action": "refine", "prev_count": prev_count},
    )
@tool
async def relax_search_query(
    context: ToolContext,
    search_anchors: str,
    result_count: str,
) -> ToolResult:
    """Relax search parameters when too few results returned.

    Args:
        context: Parlant tool context (not used by this tool).
        search_anchors: JSON string of current SearchAnchors.
        result_count: Number of results from last search, as an integer string.

    Returns:
        ToolResult with the broadened anchors as a dict. Its metadata records the
        action and the previous result count.
    """
    from trialpath.models.search_anchors import SearchAnchors

    planner = _get_planner()
    current = SearchAnchors.model_validate(json.loads(search_anchors))
    prev_count = int(result_count)
    broadened = await planner.relax_search(current, prev_count)
    return ToolResult(
        data=broadened.model_dump(),
        metadata={"action": "relax", "prev_count": prev_count},
    )
@tool
async def evaluate_trial_eligibility(
    context: ToolContext,
    patient_profile: str,
    trial_candidate: str,
) -> ToolResult:
    """Evaluate patient eligibility for a clinical trial using dual-model approach.

    Medical criteria evaluated by MedGemma, structural by Gemini.

    The planner slices the trial into criteria. Each criterion whose
    ``category`` is ``"medical"`` goes to the MedGemma extractor; every other
    criterion goes to the Gemini planner. Criteria are evaluated one at a
    time, in order. The per-criterion results are then aggregated into a
    ledger.

    Args:
        context: Parlant tool context (not used by this tool).
        patient_profile: JSON string of PatientProfile data.
        trial_candidate: JSON string of TrialCandidate data.

    Returns:
        ToolResult carrying the aggregated ledger as a dict. Its metadata
        records the number of sliced criteria.
    """
    profile = json.loads(patient_profile)
    trial = json.loads(trial_candidate)
    planner = _get_planner()
    extractor = _get_extractor()

    # Step 1: break the trial's criteria into atomic items.
    criteria = await planner.slice_criteria(trial)

    # Step 2: route each item to the matching model and merge its verdict
    # into a copy of the item (verdict keys win on collision).
    assessments = []
    for item in criteria:
        pending = (
            extractor.evaluate_medical_criterion(item["text"], profile, [])
            if item.get("category") == "medical"
            else planner.evaluate_structural_criterion(item["text"], profile)
        )
        verdict = await pending
        assessments.append({**item, **verdict})

    # Step 3: roll the individual assessments up into one ledger.
    ledger = await planner.aggregate_assessments(profile, trial, assessments)
    return ToolResult(
        data=ledger.model_dump(),
        metadata={"source": "dual_model", "criteria_count": len(criteria)},
    )
@tool
async def analyze_gaps(
    context: ToolContext,
    patient_profile: str,
    eligibility_ledgers: str,
) -> ToolResult:
    """Analyze eligibility gaps across all evaluated trials.

    Args:
        context: Parlant tool context (not used by this tool).
        patient_profile: JSON string of PatientProfile data.
        eligibility_ledgers: JSON list of EligibilityLedger data.

    Returns:
        ToolResult whose data holds the planner's gaps and their count,
        tagged with ``source="gemini"``.
    """
    gemini = _get_planner()
    found = await gemini.analyze_gaps(
        json.loads(patient_profile),
        json.loads(eligibility_ledgers),
    )
    return ToolResult(data={"gaps": found, "count": len(found)}, metadata={"source": "gemini"})
# Every tool defined in this module, listed in definition order.
# NOTE(review): presumably imported by the agent setup code to register these
# tools with Parlant. Confirm against the caller before reordering or removing.
ALL_TOOLS = [
    extract_patient_profile,
    generate_search_anchors,
    search_clinical_trials,
    refine_search_query,
    relax_search_query,
    evaluate_trial_eligibility,
    analyze_gaps,
]