# trialpath/app/services/direct_pipeline.py
"""Direct API pipeline — calls real services without Parlant orchestration.
Used as fallback when Parlant server is unavailable, or as the primary
pipeline when TRIALPATH_USE_PARLANT=true but the server can't be reached.
Flow: PatientProfile -> SearchAnchors (Gemini) -> Trials (ClinicalTrials.gov)
-> Eligibility (Gemini) -> TrialCandidates + EligibilityLedgers
"""
from __future__ import annotations
import asyncio
import concurrent.futures
import queue
import time
from dataclasses import dataclass
import structlog
from trialpath.models.eligibility_ledger import EligibilityLedger
from trialpath.models.patient_profile import PatientProfile
from trialpath.models.trial_candidate import TrialCandidate
# Module-level structured logger for all pipeline steps.
logger = structlog.get_logger("trialpath.pipeline")

# Shared worker pool: _run_async uses it to host asyncio.run off the
# caller's (Streamlit) thread. Module lifetime — never shut down here.
_THREAD_POOL = concurrent.futures.ThreadPoolExecutor(max_workers=4)

# Cap on how many trials receive a (costly) Gemini eligibility evaluation per run.
_MAX_TRIALS_TO_EVALUATE = 5
@dataclass
class PipelineProgress:
    """One status update emitted by the pipeline for the UI to display."""

    step: str  # pipeline stage id, e.g. "search_anchors", "search", "evaluate"
    message: str  # human-readable status line
    detail: str = ""  # optional extra context, e.g. "Trial 3/5: NCT12345678"
    pct: float = 0.0  # progress fraction in [0.0, 1.0]
    timestamp: float = 0.0  # creation time (epoch seconds); auto-filled when falsy

    def __post_init__(self):
        # Stamp with the current time unless the caller supplied a nonzero one.
        self.timestamp = self.timestamp or time.time()
# Thread-safe progress queue — pipeline writes, UI reads
_progress_queue: queue.Queue[PipelineProgress] = queue.Queue()

def get_progress_queue() -> queue.Queue[PipelineProgress]:
    """Return the module-level progress queue shared between pipeline and UI."""
    return _progress_queue
def _run_async(coro):
    """Execute *coro* to completion from a synchronous Streamlit context.

    The coroutine is handed to the shared worker pool, where ``asyncio.run``
    gives it a fresh event loop. Blocks until the coroutine finishes; no hard
    timeout is applied — the pipeline reports progress to the UI instead.
    """
    return _THREAD_POOL.submit(asyncio.run, coro).result()
async def _evaluate_one(
    planner,
    profile_dict: dict,
    profile_id: str,
    trial: TrialCandidate,
) -> EligibilityLedger | None:
    """Run an eligibility check for one trial; return ``None`` on failure.

    Any exception is logged and swallowed so a single bad trial cannot
    sink the whole evaluation batch.
    """
    try:
        result = await planner.evaluate_eligibility(
            patient_profile=profile_dict,
            trial_candidate=trial.model_dump(mode="json"),
        )
        # Tag the ledger with identifiers the planner does not supply.
        result.patient_id = profile_id
        result.nct_id = trial.nct_id
    except Exception:
        logger.exception("Failed to evaluate eligibility for %s", trial.nct_id)
        return None
    return result
async def _search_and_evaluate(
    profile: PatientProfile,
    progress_q: queue.Queue[PipelineProgress] | None = None,
) -> tuple[list[TrialCandidate], list[EligibilityLedger]]:
    """Run the full trial search and evaluation pipeline.

    Steps: Gemini search anchors -> ClinicalTrials.gov search -> normalize
    to TrialCandidate -> parallel Gemini eligibility evaluation for the top
    ``_MAX_TRIALS_TO_EVALUATE`` evaluable trials.

    Args:
        profile: Patient profile driving the search.
        progress_q: Optional queue receiving ``PipelineProgress`` updates for
            the UI; when ``None``, progress reporting is skipped.

    Returns:
        ``(matched_trials, ledgers)`` — only trials whose eligibility
        evaluation succeeded appear in ``matched_trials``.
    """
    # Imported lazily so the service modules (and their config/env reads)
    # are only touched when the pipeline actually runs.
    from trialpath.config import GEMINI_API_KEY
    from trialpath.services.gemini_planner import GeminiPlanner
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    def _report(step: str, message: str, detail: str = "", pct: float = 0.0):
        # Push a progress update only when the caller provided a queue.
        if progress_q:
            progress_q.put(PipelineProgress(step=step, message=message, detail=detail, pct=pct))

    planner = GeminiPlanner(api_key=GEMINI_API_KEY)
    client = ClinicalTrialsMCPClient()
    profile_dict = profile.model_dump(mode="json")
    pipeline_start = time.monotonic()

    # Step 1: Generate search anchors from profile
    _report("search_anchors", "Generating search parameters with Gemini...", pct=0.1)
    step_start = time.monotonic()
    anchors = await planner.generate_search_anchors(profile_dict)
    logger.info(
        "step_complete",
        step="search_anchors",
        condition=anchors.condition,
        duration_s=round(time.monotonic() - step_start, 2),
    )
    _report(
        "search_anchors",
        f"Search anchors ready: {anchors.condition}",
        pct=0.2,
    )

    # Step 2: Search ClinicalTrials.gov (direct API)
    _report("search", "Searching ClinicalTrials.gov...", pct=0.25)
    step_start = time.monotonic()
    raw_studies = await client.search_multi_variant(anchors)
    logger.info(
        "step_complete",
        step="ct_gov_search",
        raw_studies=len(raw_studies),
        duration_s=round(time.monotonic() - step_start, 2),
    )
    _report("search", f"Found {len(raw_studies)} studies", pct=0.35)

    # Step 3: Normalize to TrialCandidate models. Malformed payloads are
    # logged and dropped rather than failing the whole run.
    _report("normalize", f"Processing {len(raw_studies)} studies...", pct=0.4)
    trials: list[TrialCandidate] = []
    for raw in raw_studies:
        try:
            trial = ClinicalTrialsMCPClient.normalize_trial(raw)
            trials.append(trial)
        except Exception:
            # NOTE(review): nctId may be nested deeper in the CT.gov payload,
            # in which case this always logs "?" — confirm against the API shape.
            logger.warning("Failed to normalize trial: %s", raw.get("nctId", "?"))

    # Step 4: Evaluate eligibility for top N trials in parallel.
    # Only trials with parseable inclusion criteria are worth evaluating.
    evaluable = [t for t in trials if t.eligibility_text and t.eligibility_text.inclusion]
    to_evaluate = evaluable[:_MAX_TRIALS_TO_EVALUATE]
    logger.info("Evaluating eligibility for %d/%d trials", len(to_evaluate), len(trials))
    total = len(to_evaluate)
    _report(
        "evaluate",
        f"Evaluating eligibility for {total} trials in parallel...",
        pct=0.45,
    )
    eval_tasks = [
        _evaluate_one(planner, profile_dict, profile.patient_id, trial) for trial in to_evaluate
    ]
    step_start = time.monotonic()
    # _evaluate_one never raises (it returns None on failure), so gather
    # needs no return_exceptions handling here.
    results = await asyncio.gather(*eval_tasks)
    ledgers = [r for r in results if r is not None]
    logger.info(
        "step_complete",
        step="evaluate_eligibility",
        evaluated=len(ledgers),
        attempted=total,
        duration_s=round(time.monotonic() - step_start, 2),
    )
    _report("done", f"Evaluated {len(ledgers)}/{total} trials successfully", pct=0.95)

    # Only return trials that have ledgers
    evaluated_nct_ids = {lg.nct_id for lg in ledgers}
    matched_trials = [t for t in trials if t.nct_id in evaluated_nct_ids]
    total_elapsed = time.monotonic() - pipeline_start
    logger.info(
        "pipeline_complete",
        total_trials=len(matched_trials),
        total_ledgers=len(ledgers),
        total_duration_s=round(total_elapsed, 2),
    )
    _report("done", "Pipeline complete!", pct=1.0)
    return matched_trials, ledgers
def run_trial_search_and_evaluate(
    profile: PatientProfile,
    progress_q: queue.Queue[PipelineProgress] | None = None,
) -> tuple[list[TrialCandidate], list[EligibilityLedger]]:
    """Blocking entry point for the direct pipeline.

    Safe to call from Streamlit page scripts: the async pipeline runs on a
    worker thread and this call blocks until it completes.

    Returns:
        ``(trial_candidates, eligibility_ledgers)``.
    """
    pipeline = _search_and_evaluate(profile, progress_q=progress_q)
    return _run_async(pipeline)