| """Direct API pipeline — calls real services without Parlant orchestration. |
| |
Used as a fallback when the Parlant server is unavailable — including when
TRIALPATH_USE_PARLANT=true but the server can't be reached.
| |
| Flow: PatientProfile -> SearchAnchors (Gemini) -> Trials (ClinicalTrials.gov) |
| -> Eligibility (Gemini) -> TrialCandidates + EligibilityLedgers |
| """ |
|
|
| from __future__ import annotations |
|
|
| import asyncio |
| import concurrent.futures |
| import queue |
| import time |
| from dataclasses import dataclass |
|
|
| import structlog |
|
|
| from trialpath.models.eligibility_ledger import EligibilityLedger |
| from trialpath.models.patient_profile import PatientProfile |
| from trialpath.models.trial_candidate import TrialCandidate |
|
|
# Module-level structlog logger shared by every step of the direct pipeline.
logger = structlog.get_logger("trialpath.pipeline")


# Worker pool used by _run_async to host `asyncio.run` off the caller's
# (Streamlit) thread, which may not own an event loop.
_THREAD_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=4)


# Upper bound on Gemini eligibility evaluations per pipeline run —
# presumably a cost/latency guard; confirm against API quota before raising.
_MAX_TRIALS_TO_EVALUATE = 5
|
|
|
|
@dataclass
class PipelineProgress:
    """One progress update emitted by the pipeline for display in the UI."""

    step: str  # pipeline stage identifier, e.g. "search_anchors", "evaluate"
    message: str  # human-readable status line
    detail: str = ""  # optional extra context for the message
    pct: float = 0.0  # completion fraction, 0.0 through 1.0
    timestamp: float = 0.0  # wall-clock time; auto-filled when left at the default

    def __post_init__(self) -> None:
        # A falsy timestamp (the 0.0 default) is replaced with "now" so every
        # update carries a creation time without callers having to supply one.
        self.timestamp = self.timestamp or time.time()
|
|
|
|
| |
# Module-level queue shared between the pipeline worker thread and the UI.
_progress_queue: queue.Queue[PipelineProgress] = queue.Queue()


def get_progress_queue() -> queue.Queue[PipelineProgress]:
    """Return the shared queue that pipeline steps publish progress to."""
    return _progress_queue
|
|
|
|
def _run_async(coro):
    """Execute *coro* to completion from a synchronous (Streamlit) context.

    The coroutine is handed to the module's thread pool, where ``asyncio.run``
    gives it a fresh event loop; this call blocks until it finishes.
    Deliberately no hard timeout — the pipeline reports progress to the UI
    instead of being killed mid-flight.
    """
    return _THREAD_POOL.submit(asyncio.run, coro).result()
|
|
|
|
| async def _evaluate_one( |
| planner, |
| profile_dict: dict, |
| profile_id: str, |
| trial: TrialCandidate, |
| ) -> EligibilityLedger | None: |
| """Evaluate a single trial's eligibility. Returns None on failure.""" |
| try: |
| ledger = await planner.evaluate_eligibility( |
| patient_profile=profile_dict, |
| trial_candidate=trial.model_dump(mode="json"), |
| ) |
| ledger.patient_id = profile_id |
| ledger.nct_id = trial.nct_id |
| return ledger |
| except Exception: |
| logger.exception("Failed to evaluate eligibility for %s", trial.nct_id) |
| return None |
|
|
|
|
async def _search_and_evaluate(
    profile: PatientProfile,
    progress_q: queue.Queue[PipelineProgress] | None = None,
) -> tuple[list[TrialCandidate], list[EligibilityLedger]]:
    """Run the full trial search and evaluation pipeline.

    Steps: generate search anchors with Gemini, query ClinicalTrials.gov,
    normalize the raw studies, then evaluate eligibility for up to
    _MAX_TRIALS_TO_EVALUATE trials concurrently.

    Args:
        profile: Patient profile to match against trials.
        progress_q: Optional queue receiving PipelineProgress updates for the
            UI; when None, progress reporting is skipped.

    Returns:
        (matched_trials, ledgers) — matched_trials contains only the trials
        whose eligibility was successfully evaluated.
    """
    # Imported lazily so merely importing this module does not pull in the
    # service clients or their configuration.
    from trialpath.config import GEMINI_API_KEY
    from trialpath.services.gemini_planner import GeminiPlanner
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient


    def _report(step: str, message: str, detail: str = "", pct: float = 0.0):
        # Best-effort progress publishing; no-op when no queue was supplied.
        # NOTE(review): `if progress_q:` works because queue.Queue defines no
        # __bool__/__len__, so any instance is truthy — effectively `is not None`.
        if progress_q:
            progress_q.put(PipelineProgress(step=step, message=message, detail=detail, pct=pct))


    planner = GeminiPlanner(api_key=GEMINI_API_KEY)
    # NOTE(review): client is never explicitly closed here — confirm
    # ClinicalTrialsMCPClient needs no teardown (no `async with` in sight).
    client = ClinicalTrialsMCPClient()


    profile_dict = profile.model_dump(mode="json")


    pipeline_start = time.monotonic()


    # Step 1: ask Gemini for ClinicalTrials.gov search parameters (anchors).
    _report("search_anchors", "Generating search parameters with Gemini...", pct=0.1)
    step_start = time.monotonic()
    anchors = await planner.generate_search_anchors(profile_dict)
    logger.info(
        "step_complete",
        step="search_anchors",
        condition=anchors.condition,
        duration_s=round(time.monotonic() - step_start, 2),
    )
    _report(
        "search_anchors",
        f"Search anchors ready: {anchors.condition}",
        pct=0.2,
    )


    # Step 2: query ClinicalTrials.gov with the generated anchors.
    _report("search", "Searching ClinicalTrials.gov...", pct=0.25)
    step_start = time.monotonic()
    raw_studies = await client.search_multi_variant(anchors)
    logger.info(
        "step_complete",
        step="ct_gov_search",
        raw_studies=len(raw_studies),
        duration_s=round(time.monotonic() - step_start, 2),
    )
    _report("search", f"Found {len(raw_studies)} studies", pct=0.35)


    # Step 3: normalize raw study payloads into TrialCandidate models;
    # individual malformed studies are skipped rather than failing the run.
    _report("normalize", f"Processing {len(raw_studies)} studies...", pct=0.4)
    trials: list[TrialCandidate] = []
    for raw in raw_studies:
        try:
            trial = ClinicalTrialsMCPClient.normalize_trial(raw)
            trials.append(trial)
        except Exception:
            logger.warning("Failed to normalize trial: %s", raw.get("nctId", "?"))


    # Step 4: keep only trials with usable inclusion criteria, then cap the
    # count to bound Gemini cost/latency.
    evaluable = [t for t in trials if t.eligibility_text and t.eligibility_text.inclusion]
    to_evaluate = evaluable[:_MAX_TRIALS_TO_EVALUATE]
    logger.info("Evaluating eligibility for %d/%d trials", len(to_evaluate), len(trials))


    total = len(to_evaluate)
    _report(
        "evaluate",
        f"Evaluating eligibility for {total} trials in parallel...",
        pct=0.45,
    )


    # All evaluations run concurrently; _evaluate_one returns None on failure
    # instead of raising, so gather never aborts the whole batch.
    eval_tasks = [
        _evaluate_one(planner, profile_dict, profile.patient_id, trial) for trial in to_evaluate
    ]
    step_start = time.monotonic()
    results = await asyncio.gather(*eval_tasks)
    ledgers = [r for r in results if r is not None]
    logger.info(
        "step_complete",
        step="evaluate_eligibility",
        evaluated=len(ledgers),
        attempted=total,
        duration_s=round(time.monotonic() - step_start, 2),
    )


    _report("done", f"Evaluated {len(ledgers)}/{total} trials successfully", pct=0.95)


    # Keep only trials that produced a ledger, preserving search order.
    evaluated_nct_ids = {lg.nct_id for lg in ledgers}
    matched_trials = [t for t in trials if t.nct_id in evaluated_nct_ids]


    total_elapsed = time.monotonic() - pipeline_start
    logger.info(
        "pipeline_complete",
        total_trials=len(matched_trials),
        total_ledgers=len(ledgers),
        total_duration_s=round(total_elapsed, 2),
    )


    _report("done", "Pipeline complete!", pct=1.0)
    return matched_trials, ledgers
|
|
|
|
def run_trial_search_and_evaluate(
    profile: PatientProfile,
    progress_q: queue.Queue[PipelineProgress] | None = None,
) -> tuple[list[TrialCandidate], list[EligibilityLedger]]:
    """Synchronous entry point for the direct pipeline.

    Safe to call from Streamlit page scripts: the async pipeline runs on a
    worker thread and this call blocks until it completes.

    Returns:
        (trial_candidates, eligibility_ledgers).
    """
    pipeline = _search_and_evaluate(profile, progress_q=progress_q)
    return _run_async(pipeline)
|
|