# bshepp
# Implement validation pipeline fixes (P1-P7) and experimental track system
# 28f1212
# [Track C: Iterative Refinement]
"""
Track C — Run iterative refinement experiments against MedQA.
Usage:
cd src/backend
python -m tracks.iterative.run_iterative # all configs
python -m tracks.iterative.run_iterative --config C1_3rounds # single config
python -m tracks.iterative.run_iterative --max-cases 10 # quick test
Each config runs the full baseline pipeline first, then feeds the initial
reasoning through N self-critique iterations. Results include per-iteration
accuracy AND cost, enabling cost/benefit charts.
"""
from __future__ import annotations
import asyncio
import json
import logging
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional
# Backend root: three directory levels above this file. It is put at the front
# of sys.path (once) so the `app`, `tracks` and `validation` imports below
# resolve no matter which directory the script is launched from.
BACKEND_DIR = Path(__file__).resolve().parent.parent.parent
_backend_root = str(BACKEND_DIR)
if _backend_root not in sys.path:
    sys.path.insert(0, _backend_root)
from app.agent.orchestrator import Orchestrator
from app.models.schemas import (
CaseSubmission,
CDSReport,
AgentStepStatus,
ClinicalReasoningResult,
)
from tracks.iterative.config import CONFIGS, IterativeConfig
from tracks.iterative.refiner import IterativeRefiner
from tracks.shared.cost_tracker import CostLedger
from validation.base import (
ValidationCase,
ValidationResult,
ValidationSummary,
diagnosis_in_differential,
run_cds_pipeline,
save_results,
print_summary,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Result files are written to a "results" folder next to this module.
RESULTS_DIR = Path(__file__).resolve().parent.joinpath("results")

# MedQA test set (JSON Lines) inside the backend's validation data folder.
MEDQA_PATH = BACKEND_DIR.joinpath("validation", "data", "medqa_test.jsonl")
# ──────────────────────────────────────────────
# Per-case runner
# ──────────────────────────────────────────────
async def run_case_iterative(
    case: ValidationCase,
    config: IterativeConfig,
    ledger: CostLedger,
) -> ValidationResult:
    """
    Evaluate one validation case under a single iterative-refinement config.

    Stages:
        1. Run the baseline CDS pipeline (drug check and guidelines enabled).
        2. Pass the baseline reasoning through ``IterativeRefiner.refine``.
        3. Re-run synthesis with the refined reasoning; if synthesis raises,
           the baseline report is reused and a warning is logged.

    Args:
        case: Case whose ``input_text`` is sent to the pipeline and whose
            ``ground_truth["answer"]`` (when present) is scored against.
        config: Refinement settings; ``config_id`` tags the result's dataset.
        ledger: Cost ledger handed to the refiner and serialised into the
            result details.

    Returns:
        A ``ValidationResult`` with ``success=False`` when the baseline run
        reports an error or lacks clinical reasoning / a patient profile;
        otherwise ``success=True`` with baseline and refined scores plus
        per-iteration details.
    """

    def _flag(condition) -> float:
        # Scores are stored as 1.0 / 0.0 so they can be averaged later.
        return 1.0 if condition else 0.0

    started = time.monotonic()

    # ── Stage 1: baseline pipeline ──
    state, report, error = await run_cds_pipeline(
        patient_text=case.input_text,
        include_drug_check=True,
        include_guidelines=True,
    )
    baseline_unusable = (
        error
        or not state
        or not state.clinical_reasoning
        or not state.patient_profile
    )
    if baseline_unusable:
        return ValidationResult(
            case_id=case.case_id,
            source_dataset=f"trackC_{config.config_id}",
            success=False,
            scores={},
            pipeline_time_ms=int((time.monotonic() - started) * 1000),
            error=error or "Baseline pipeline failed to produce reasoning",
        )

    # ── Stage 2: iterative refinement ──
    refiner = IterativeRefiner(config, ledger)
    refined_reasoning, history = await refiner.refine(
        profile=state.patient_profile,
        initial_reasoning=state.clinical_reasoning,
    )

    # ── Stage 3: synthesis re-run ──
    # Only the synthesis step is repeated; every other input comes from the
    # baseline state, with the refined reasoning swapped in.
    from app.tools.synthesis import SynthesisTool

    synth = SynthesisTool()
    try:
        refined_report = await synth.run(
            patient_profile=state.patient_profile,
            clinical_reasoning=refined_reasoning,
            drug_interactions=state.drug_interactions,
            guideline_retrieval=state.guideline_retrieval,
            conflict_detection=state.conflict_detection,
        )
    except Exception as e:
        # Best effort: keep the case scorable using the baseline report.
        refined_report = report
        logger.warning(f"Re-synthesis failed, using baseline report: {e}")

    elapsed_ms = int((time.monotonic() - started) * 1000)

    # The report that gets scored and summarised: refined if available,
    # otherwise the baseline one.
    final_report = refined_report or report

    # ── Scoring: baseline vs. refined ──
    scores: dict = {}
    # history's first entry is the initial reasoning, so it is not counted.
    details: dict = {"iterations": len(history) - 1}

    if "answer" in case.ground_truth:
        expected = case.ground_truth["answer"]

        if report:
            base_hit, base_rank, _base_loc = diagnosis_in_differential(expected, report)
            scores["baseline_top1"] = _flag(base_hit and base_rank == 0)
            scores["baseline_mentioned"] = _flag(base_hit)

        if final_report:
            hit, rank, location = diagnosis_in_differential(expected, final_report)
            scores["top1_accuracy"] = _flag(hit and rank == 0)
            scores["top3_accuracy"] = _flag(hit and rank < 3)
            scores["mentioned"] = _flag(hit)
            details["rank"] = rank
            details["match_location"] = location

        details["improved"] = scores.get("top1_accuracy", 0) > scores.get("baseline_top1", 0)

    # Top diagnosis after each iteration ("?" when a snapshot has none),
    # kept for cost/benefit charts.
    top_dx_trace = []
    for snapshot in history:
        ranked = snapshot.differential_diagnosis
        top_dx_trace.append(ranked[0].diagnosis if ranked else "?")
    details["per_iteration_top_dx"] = top_dx_trace
    details["cost_ledger"] = ledger.to_dict()

    summary_text = final_report.patient_summary[:200] if final_report else None
    return ValidationResult(
        case_id=case.case_id,
        source_dataset=f"trackC_{config.config_id}",
        success=True,
        scores=scores,
        pipeline_time_ms=elapsed_ms,
        report_summary=summary_text,
        details=details,
    )
# ──────────────────────────────────────────────
# Experiment runner
# ──────────────────────────────────────────────
async def run_config(
    config: IterativeConfig,
    cases: List[ValidationCase],
) -> ValidationSummary:
    """
    Run every case through one iterative config and aggregate the outcome.

    Cases are processed one after another, each with its own ``CostLedger``.
    Accuracy metrics are averaged over successful cases that carry the given
    score; ``pipeline_success`` is the share of cases that succeeded. Any
    average over an empty set is reported as 0.0.

    Args:
        config: The iterative-refinement configuration to evaluate.
        cases: Validation cases to run.

    Returns:
        A ``ValidationSummary`` holding counts, aggregate metrics, the
        per-case results and the wall-clock duration in seconds (1 decimal).
    """

    def _mean(values) -> float:
        return sum(values) / len(values) if values else 0.0

    results: List[ValidationResult] = []
    began = time.monotonic()
    for i, case in enumerate(cases, 1):
        logger.info(f" [{config.config_id}] case {i}/{len(cases)}: {case.case_id}")
        case_ledger = CostLedger(track_id=f"C_{config.config_id}")
        results.append(await run_case_iterative(case, config, case_ledger))
    duration = time.monotonic() - began

    passed = [r for r in results if r.success]

    score_keys = (
        "top1_accuracy",
        "top3_accuracy",
        "mentioned",
        "baseline_top1",
        "baseline_mentioned",
    )
    metrics = {
        key: _mean([r.scores[key] for r in passed if key in r.scores])
        for key in score_keys
    }
    metrics["pipeline_success"] = len(passed) / len(results) if results else 0.0

    # Mean number of refinement iterations per successful case.
    metrics["avg_iterations"] = _mean([r.details.get("iterations", 0) for r in passed])

    # Share of successful cases flagged as improved over the baseline.
    improved_count = len([r for r in passed if r.details.get("improved")])
    metrics["improvement_rate"] = improved_count / len(passed) if passed else 0.0

    return ValidationSummary(
        dataset=f"trackC_{config.config_id}",
        total_cases=len(results),
        successful_cases=len(passed),
        failed_cases=len(results) - len(passed),
        metrics=metrics,
        per_case=results,
        run_duration_sec=round(duration, 1),
    )
# ──────────────────────────────────────────────
# Data loading (reuse from validation)
# ──────────────────────────────────────────────
def load_medqa_cases(max_cases: Optional[int] = None) -> List[ValidationCase]:
    """
    Load MedQA validation cases from the JSON Lines file at ``MEDQA_PATH``.

    Each non-blank line is parsed as one JSON object. The question text is
    read from ``"question"`` (falling back to ``"input"``), the ground-truth
    answer from ``"answer"`` (falling back to ``"target"``), and the case id
    from ``"id"`` (falling back to ``medqa_<line number>``).

    Lines that are not valid JSON, or that decode to something other than a
    JSON object, are skipped with a warning instead of aborting the whole
    load, so one corrupt record cannot take down an entire experiment run.

    Args:
        max_cases: Stop after this many cases have been loaded. ``None`` or
            ``0`` means no limit.

    Returns:
        The loaded cases in file order; an empty list if the data file does
        not exist (an error is logged in that case).
    """
    if not MEDQA_PATH.exists():
        logger.error(f"MedQA data not found at {MEDQA_PATH}")
        return []

    cases: List[ValidationCase] = []
    with open(MEDQA_PATH, "r", encoding="utf-8") as f:
        for ln, line in enumerate(f, 1):
            if max_cases and len(cases) >= max_cases:
                break
            if not line.strip():
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError as exc:
                logger.warning("Skipping malformed MedQA record at line %d: %s", ln, exc)
                continue
            if not isinstance(data, dict):
                # The field lookups below need a JSON object, not a list/scalar.
                logger.warning(
                    "Skipping MedQA record at line %d: expected a JSON object, got %s",
                    ln,
                    type(data).__name__,
                )
                continue

            cases.append(ValidationCase(
                case_id=data.get("id", f"medqa_{ln}"),
                source_dataset="medqa",
                input_text=data.get("question", data.get("input", "")),
                ground_truth={"answer": data.get("answer", data.get("target", ""))},
                metadata=data.get("metadata", {}),
            ))
    return cases
# ──────────────────────────────────────────────
# CLI
# ──────────────────────────────────────────────
async def main():
    """
    Command-line entry point for the Track C experiments.

    Parses ``--config``, ``--max-cases`` and ``--quiet``, configures logging,
    waits for the model endpoint, then runs each selected config over the
    loaded MedQA cases. Each summary is saved as a timestamped JSON file under
    ``RESULTS_DIR`` and printed unless ``--quiet`` is set.

    Exits with status 1 when the endpoint is unreachable, the requested
    config id is unknown, or no cases could be loaded.
    """
    import argparse
    import shutil

    cli = argparse.ArgumentParser(description="Track C: Iterative refinement experiments")
    cli.add_argument("--config", type=str, default=None, help="Run a single config by ID")
    cli.add_argument("--max-cases", type=int, default=None, help="Limit cases per config")
    cli.add_argument("--quiet", action="store_true")
    args = cli.parse_args()

    log_level = logging.WARNING if args.quiet else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Wait for endpoint to be online (handles scale-to-zero)
    from tracks.shared.endpoint_check import wait_for_endpoint

    endpoint_ready = await wait_for_endpoint(quiet=args.quiet)
    if not endpoint_ready:
        print("ABORT: MedGemma endpoint is not reachable. Resume it and try again.")
        sys.exit(1)

    # Select either the single requested config or all of them.
    if args.config:
        configs = [c for c in CONFIGS if c.config_id == args.config]
        if not configs:
            print(f"Unknown config: {args.config}")
            print(f"Available: {[c.config_id for c in CONFIGS]}")
            sys.exit(1)
    else:
        configs = CONFIGS

    cases = load_medqa_cases(args.max_cases)
    if not cases:
        print("No MedQA cases loaded")
        sys.exit(1)
    print(f"Loaded {len(cases)} MedQA cases\n")

    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    for cfg in configs:
        print(f"\n{'='*60}")
        print(f" Running config: {cfg.config_id}")
        print(f" {cfg.description}")
        print(f" Max iterations: {cfg.max_iterations}, convergence: {cfg.convergence_threshold}")
        print(f"{'='*60}")

        summary = await run_config(cfg, cases)

        ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        fname = f"trackC_{cfg.config_id}_{ts}.json"
        path = RESULTS_DIR / fname

        # save_results picks its own output location; relocate the file when
        # that is not the Track C results folder.
        save_path = save_results(summary, filename=fname)
        if save_path != path:
            shutil.move(str(save_path), str(path))

        if not args.quiet:
            print_summary(summary)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())