"""
Track C — Run iterative refinement experiments against MedQA.

Usage:
    cd src/backend
    python -m tracks.iterative.run_iterative                       # all configs
    python -m tracks.iterative.run_iterative --config C1_3rounds   # single config
    python -m tracks.iterative.run_iterative --max-cases 10        # quick test

Each config runs the full baseline pipeline first, then feeds the initial
reasoning through N self-critique iterations. Results include per-iteration
accuracy AND cost, enabling cost/benefit charts.
"""
from __future__ import annotations

import asyncio
import json
import logging
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional

# Make the backend package root importable when this file is run as a script.
BACKEND_DIR = Path(__file__).resolve().parent.parent.parent
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

from tracks.iterative.config import CONFIGS, IterativeConfig
from tracks.iterative.refiner import IterativeRefiner
from tracks.shared.cost_tracker import CostLedger
from validation.base import (
    ValidationCase,
    ValidationResult,
    ValidationSummary,
    diagnosis_in_differential,
    run_cds_pipeline,
    save_results,
    print_summary,
)

logger = logging.getLogger(__name__)

RESULTS_DIR = Path(__file__).resolve().parent / "results"
MEDQA_PATH = BACKEND_DIR / "validation" / "data" / "medqa_test.jsonl"


async def run_case_iterative(
    case: ValidationCase,
    config: IterativeConfig,
    ledger: CostLedger,
) -> ValidationResult:
    """
    Run one case through:
      1. the baseline pipeline (Track A) to get initial reasoning,
      2. the iterative refinement loop (Track C),
      3. re-synthesis of the final report with the refined differential.
    """
    t0 = time.monotonic()

    # 1) Baseline pipeline (Track A): produce the initial reasoning and report.
    state, report, error = await run_cds_pipeline(
        patient_text=case.input_text,
        include_drug_check=True,
        include_guidelines=True,
    )

    if error or not state or not state.clinical_reasoning or not state.patient_profile:
        return ValidationResult(
            case_id=case.case_id,
            source_dataset=f"trackC_{config.config_id}",
            success=False,
            scores={},
            pipeline_time_ms=int((time.monotonic() - t0) * 1000),
            error=error or "Baseline pipeline failed to produce reasoning",
        )

    # 2) Iterative self-critique refinement (Track C).
    refiner = IterativeRefiner(config, ledger)
    refined_reasoning, history = await refiner.refine(
        profile=state.patient_profile,
        initial_reasoning=state.clinical_reasoning,
    )

    # 3) Re-synthesize the final report with the refined differential;
    # fall back to the baseline report if synthesis fails.
    from app.tools.synthesis import SynthesisTool

    synth = SynthesisTool()
    try:
        refined_report = await synth.run(
            patient_profile=state.patient_profile,
            clinical_reasoning=refined_reasoning,
            drug_interactions=state.drug_interactions,
            guideline_retrieval=state.guideline_retrieval,
            conflict_detection=state.conflict_detection,
        )
    except Exception as e:
        refined_report = report
        logger.warning(f"Re-synthesis failed, using baseline report: {e}")

    elapsed_ms = int((time.monotonic() - t0) * 1000)

    # Score both the baseline and the refined output against ground truth.
    scores: dict = {}
    details: dict = {"iterations": len(history) - 1}  # history[0] is the initial pass

    if "answer" in case.ground_truth:
        gt = case.ground_truth["answer"]

        # Baseline scores (pre-refinement), for before/after comparison.
        if report:
            b_found, b_rank, _ = diagnosis_in_differential(gt, report)
            scores["baseline_top1"] = 1.0 if (b_found and b_rank == 0) else 0.0
            scores["baseline_mentioned"] = 1.0 if b_found else 0.0

        # Refined scores; fall back to the baseline report if re-synthesis failed.
        target_report = refined_report or report
        if target_report:
            r_found, r_rank, r_loc = diagnosis_in_differential(gt, target_report)
            scores["top1_accuracy"] = 1.0 if (r_found and r_rank == 0) else 0.0
            scores["top3_accuracy"] = 1.0 if (r_found and r_rank < 3) else 0.0
            scores["mentioned"] = 1.0 if r_found else 0.0
            details["rank"] = r_rank
            details["match_location"] = r_loc
            details["improved"] = (
                scores.get("top1_accuracy", 0) > scores.get("baseline_top1", 0)
            )

    # Record how the top diagnosis evolved across iterations, plus token cost.
    details["per_iteration_top_dx"] = [
        h.differential_diagnosis[0].diagnosis if h.differential_diagnosis else "?"
        for h in history
    ]
    details["cost_ledger"] = ledger.to_dict()

    final_report = refined_report or report
    return ValidationResult(
        case_id=case.case_id,
        source_dataset=f"trackC_{config.config_id}",
        success=True,
        scores=scores,
        pipeline_time_ms=elapsed_ms,
        report_summary=final_report.patient_summary[:200] if final_report else None,
        details=details,
    )


async def run_config(
    config: IterativeConfig,
    cases: List[ValidationCase],
) -> ValidationSummary:
    """Run every case through one iterative config and aggregate metrics."""
    results: List[ValidationResult] = []
    start = time.monotonic()

    for i, case in enumerate(cases, 1):
        logger.info(f"  [{config.config_id}] case {i}/{len(cases)}: {case.case_id}")
        ledger = CostLedger(track_id=f"C_{config.config_id}")  # fresh ledger per case
        vr = await run_case_iterative(case, config, ledger)
        results.append(vr)

    elapsed = time.monotonic() - start
    successful = [r for r in results if r.success]

    # Mean of each score across successful cases (cases missing a key are skipped).
    metrics = {}
    for key in ("top1_accuracy", "top3_accuracy", "mentioned", "baseline_top1", "baseline_mentioned"):
        vals = [r.scores[key] for r in successful if key in r.scores]
        metrics[key] = sum(vals) / len(vals) if vals else 0.0
    metrics["pipeline_success"] = len(successful) / len(results) if results else 0.0

    iters = [r.details.get("iterations", 0) for r in successful]
    metrics["avg_iterations"] = sum(iters) / len(iters) if iters else 0.0

    # Fraction of cases where refinement flipped top-1 from wrong to right.
    improved = [r for r in successful if r.details.get("improved")]
    metrics["improvement_rate"] = len(improved) / len(successful) if successful else 0.0

    return ValidationSummary(
        dataset=f"trackC_{config.config_id}",
        total_cases=len(results),
        successful_cases=len(successful),
        failed_cases=len(results) - len(successful),
        metrics=metrics,
        per_case=results,
        run_duration_sec=round(elapsed, 1),
    )
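
# Illustrative shape of the aggregated metrics dict built above (keys come from
# this module; the numbers are made-up examples, not real results):
#   {"top1_accuracy": 0.62, "top3_accuracy": 0.78, "mentioned": 0.85,
#    "baseline_top1": 0.55, "baseline_mentioned": 0.82, "pipeline_success": 0.96,
#    "avg_iterations": 2.1, "improvement_rate": 0.18}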


def load_medqa_cases(max_cases: Optional[int] = None) -> List[ValidationCase]:
    """Load MedQA cases from the JSONL file, optionally capped at max_cases."""
    if not MEDQA_PATH.exists():
        logger.error(f"MedQA data not found at {MEDQA_PATH}")
        return []
    cases: List[ValidationCase] = []
    with open(MEDQA_PATH, "r", encoding="utf-8") as f:
        for ln, line in enumerate(f, 1):
            if max_cases and len(cases) >= max_cases:
                break
            if not line.strip():
                continue
            data = json.loads(line)
            cases.append(ValidationCase(
                case_id=data.get("id", f"medqa_{ln}"),
                source_dataset="medqa",
                input_text=data.get("question", data.get("input", "")),
                ground_truth={"answer": data.get("answer", data.get("target", ""))},
                metadata=data.get("metadata", {}),
            ))
    return cases
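
# Expected JSONL record shape (illustrative example; the loader accepts either
# "question"/"input" for the prompt and "answer"/"target" for ground truth):
#   {"id": "medqa_0001", "question": "A 45-year-old man presents with ...",
#    "answer": "Asthma", "metadata": {...}}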


async def main():
    import argparse

    parser = argparse.ArgumentParser(description="Track C: Iterative refinement experiments")
    parser.add_argument("--config", type=str, default=None, help="Run a single config by ID")
    parser.add_argument("--max-cases", type=int, default=None, help="Limit cases per config")
    parser.add_argument("--quiet", action="store_true", help="Only log warnings and errors")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.WARNING if args.quiet else logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Fail fast if the MedGemma endpoint is unreachable.
    from tracks.shared.endpoint_check import wait_for_endpoint
    if not await wait_for_endpoint(quiet=args.quiet):
        print("ABORT: MedGemma endpoint is not reachable. Resume it and try again.")
        sys.exit(1)

    configs = CONFIGS
    if args.config:
        configs = [c for c in CONFIGS if c.config_id == args.config]
        if not configs:
            print(f"Unknown config: {args.config}")
            print(f"Available: {[c.config_id for c in CONFIGS]}")
            sys.exit(1)

    cases = load_medqa_cases(args.max_cases)
    if not cases:
        print("No MedQA cases loaded")
        sys.exit(1)
    print(f"Loaded {len(cases)} MedQA cases\n")

    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    for cfg in configs:
        print(f"\n{'='*60}")
        print(f"  Running config: {cfg.config_id}")
        print(f"  {cfg.description}")
        print(f"  Max iterations: {cfg.max_iterations}, convergence: {cfg.convergence_threshold}")
        print(f"{'='*60}")

        summary = await run_config(cfg, cases)

        ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        fname = f"trackC_{cfg.config_id}_{ts}.json"
        path = RESULTS_DIR / fname

        # save_results may write to its default results dir; relocate the file
        # into Track C's own results folder if it landed elsewhere.
        save_path = save_results(summary, filename=fname)
        if save_path != path:
            import shutil
            shutil.move(str(save_path), str(path))

        if not args.quiet:
            print_summary(summary)


if __name__ == "__main__":
    asyncio.run(main())