File size: 11,216 Bytes
28f1212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
# [Track C: Iterative Refinement]
"""
Track C — Run iterative refinement experiments against MedQA.

Usage:
    cd src/backend
    python -m tracks.iterative.run_iterative                    # all configs
    python -m tracks.iterative.run_iterative --config C1_3rounds  # single config
    python -m tracks.iterative.run_iterative --max-cases 10       # quick test

Each config runs the full baseline pipeline first, then feeds the initial
reasoning through N self-critique iterations. Results include per-iteration
accuracy AND cost, enabling cost/benefit charts.
"""
from __future__ import annotations

import asyncio
import json
import logging
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional

BACKEND_DIR = Path(__file__).resolve().parent.parent.parent
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

from app.agent.orchestrator import Orchestrator
from app.models.schemas import (
    CaseSubmission,
    CDSReport,
    AgentStepStatus,
    ClinicalReasoningResult,
)
from tracks.iterative.config import CONFIGS, IterativeConfig
from tracks.iterative.refiner import IterativeRefiner
from tracks.shared.cost_tracker import CostLedger
from validation.base import (
    ValidationCase,
    ValidationResult,
    ValidationSummary,
    diagnosis_in_differential,
    run_cds_pipeline,
    save_results,
    print_summary,
)

logger = logging.getLogger(__name__)

RESULTS_DIR = Path(__file__).resolve().parent / "results"
MEDQA_PATH = BACKEND_DIR / "validation" / "data" / "medqa_test.jsonl"


# ──────────────────────────────────────────────
# Per-case runner
# ──────────────────────────────────────────────

async def run_case_iterative(
    case: ValidationCase,
    config: IterativeConfig,
    ledger: CostLedger,
) -> ValidationResult:
    """
    Run one case through:
      1. The baseline pipeline (Track A) to get initial reasoning
      2. The iterative refinement loop (Track C)
      3. Re-synthesize the final report with the refined differential

    Args:
        case: Validation case (input text plus ground-truth answer).
        config: Iteration limits / convergence settings for this run.
        ledger: Mutable cost accumulator; refinement costs recorded here.

    Returns:
        A ValidationResult carrying both baseline and refined scores so
        the two can be compared per-case downstream.
    """
    t0 = time.monotonic()

    # ── Step 1: Run baseline pipeline ──
    state, report, error = await run_cds_pipeline(
        patient_text=case.input_text,
        include_drug_check=True,
        include_guidelines=True,
    )

    # Bail out early: without reasoning + profile there is nothing to refine.
    if error or not state or not state.clinical_reasoning or not state.patient_profile:
        return ValidationResult(
            case_id=case.case_id,
            source_dataset=f"trackC_{config.config_id}",
            success=False,
            scores={},
            pipeline_time_ms=int((time.monotonic() - t0) * 1000),
            error=error or "Baseline pipeline failed to produce reasoning",
        )

    # ── Step 2: Iterative refinement ──
    refiner = IterativeRefiner(config, ledger)
    refined_reasoning, history = await refiner.refine(
        profile=state.patient_profile,
        initial_reasoning=state.clinical_reasoning,
    )

    # ── Step 3: Re-synthesize with the refined differential ──
    # Inject the refined reasoning back into the orchestrator state and
    # re-run just the synthesis step.
    from app.tools.synthesis import SynthesisTool
    synth = SynthesisTool()
    try:
        refined_report = await synth.run(
            patient_profile=state.patient_profile,
            clinical_reasoning=refined_reasoning,
            drug_interactions=state.drug_interactions,
            guideline_retrieval=state.guideline_retrieval,
            conflict_detection=state.conflict_detection,
        )
    except Exception as e:
        # Best-effort: a synthesis failure must not sink the whole case.
        refined_report = report  # Fall back to baseline report
        # Lazy %-args so the message is only formatted when emitted.
        logger.warning("Re-synthesis failed, using baseline report: %s", e)

    elapsed_ms = int((time.monotonic() - t0) * 1000)

    # ── Score — compare baseline vs. refined ──
    scores: dict = {}
    # history includes the initial reasoning, hence len - 1; clamp at 0 in
    # case the refiner ever returns an empty history.
    details: dict = {"iterations": max(len(history) - 1, 0)}

    if "answer" in case.ground_truth:
        gt = case.ground_truth["answer"]

        # Score the baseline report (rank 0 == top-1 hit).
        if report:
            b_found, b_rank, b_loc = diagnosis_in_differential(gt, report)
            scores["baseline_top1"] = 1.0 if (b_found and b_rank == 0) else 0.0
            scores["baseline_mentioned"] = 1.0 if b_found else 0.0

        # Score the refined report (falls back to baseline on synth failure).
        target_report = refined_report or report
        if target_report:
            r_found, r_rank, r_loc = diagnosis_in_differential(gt, target_report)
            scores["top1_accuracy"] = 1.0 if (r_found and r_rank == 0) else 0.0
            scores["top3_accuracy"] = 1.0 if (r_found and r_rank < 3) else 0.0
            scores["mentioned"] = 1.0 if r_found else 0.0
            details["rank"] = r_rank
            details["match_location"] = r_loc
            details["improved"] = scores.get("top1_accuracy", 0) > scores.get("baseline_top1", 0)

    # Per-iteration differential snapshots (for cost/benefit charts)
    details["per_iteration_top_dx"] = [
        h.differential_diagnosis[0].diagnosis if h.differential_diagnosis else "?"
        for h in history
    ]
    details["cost_ledger"] = ledger.to_dict()

    # Single source of truth for which report we ended up with.
    final_report = refined_report or report
    return ValidationResult(
        case_id=case.case_id,
        source_dataset=f"trackC_{config.config_id}",
        success=True,
        scores=scores,
        pipeline_time_ms=elapsed_ms,
        report_summary=final_report.patient_summary[:200] if final_report else None,
        details=details,
    )


# ──────────────────────────────────────────────
# Experiment runner
# ──────────────────────────────────────────────

async def run_config(
    config: IterativeConfig,
    cases: List[ValidationCase],
) -> ValidationSummary:
    """Run every case through one iterative config and aggregate metrics."""
    t_start = time.monotonic()
    all_results: List[ValidationResult] = []

    for idx, case in enumerate(cases, 1):
        logger.info(f"  [{config.config_id}] case {idx}/{len(cases)}: {case.case_id}")
        case_ledger = CostLedger(track_id=f"C_{config.config_id}")
        all_results.append(await run_case_iterative(case, config, case_ledger))

    duration = time.monotonic() - t_start
    ok = [r for r in all_results if r.success]

    # Mean of each score, taken over the cases that actually produced it.
    metrics = {}
    score_keys = ("top1_accuracy", "top3_accuracy", "mentioned", "baseline_top1", "baseline_mentioned")
    for key in score_keys:
        observed = [r.scores[key] for r in ok if key in r.scores]
        metrics[key] = (sum(observed) / len(observed)) if observed else 0.0
    metrics["pipeline_success"] = (len(ok) / len(all_results)) if all_results else 0.0

    # Mean number of refinement iterations across successful cases.
    iteration_counts = [r.details.get("iterations", 0) for r in ok]
    metrics["avg_iterations"] = (sum(iteration_counts) / len(iteration_counts)) if iteration_counts else 0.0

    # Fraction of successful cases where refinement beat the baseline.
    n_improved = sum(1 for r in ok if r.details.get("improved"))
    metrics["improvement_rate"] = (n_improved / len(ok)) if ok else 0.0

    return ValidationSummary(
        dataset=f"trackC_{config.config_id}",
        total_cases=len(all_results),
        successful_cases=len(ok),
        failed_cases=len(all_results) - len(ok),
        metrics=metrics,
        per_case=all_results,
        run_duration_sec=round(duration, 1),
    )


# ──────────────────────────────────────────────
# Data loading (reuse from validation)
# ──────────────────────────────────────────────

def load_medqa_cases(max_cases: Optional[int] = None) -> List[ValidationCase]:
    """
    Load MedQA cases from the JSONL file at MEDQA_PATH.

    Args:
        max_cases: Optional cap on the number of cases returned; None loads all.

    Returns:
        A list of ValidationCase objects; empty if the file is missing.
    """
    if not MEDQA_PATH.exists():
        logger.error(f"MedQA data not found at {MEDQA_PATH}")
        return []
    cases: List[ValidationCase] = []
    with MEDQA_PATH.open("r", encoding="utf-8") as f:
        for ln, line in enumerate(f, 1):
            if max_cases and len(cases) >= max_cases:
                break
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                # One corrupt line should not abort the whole run; skip it.
                logger.warning("Skipping malformed JSONL at %s:%d", MEDQA_PATH, ln)
                continue
            cases.append(ValidationCase(
                case_id=data.get("id", f"medqa_{ln}"),
                source_dataset="medqa",
                input_text=data.get("question", data.get("input", "")),
                ground_truth={"answer": data.get("answer", data.get("target", ""))},
                metadata=data.get("metadata", {}),
            ))
    return cases


# ──────────────────────────────────────────────
# CLI
# ──────────────────────────────────────────────

async def main():
    """CLI entry point: parse args, verify the endpoint, run each config."""
    import argparse
    parser = argparse.ArgumentParser(description="Track C: Iterative refinement experiments")
    parser.add_argument("--config", type=str, default=None, help="Run a single config by ID")
    parser.add_argument("--max-cases", type=int, default=None, help="Limit cases per config")
    parser.add_argument("--quiet", action="store_true")
    args = parser.parse_args()

    log_level = logging.WARNING if args.quiet else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    # Wait for endpoint to be online (handles scale-to-zero)
    from tracks.shared.endpoint_check import wait_for_endpoint
    endpoint_ok = await wait_for_endpoint(quiet=args.quiet)
    if not endpoint_ok:
        print("ABORT: MedGemma endpoint is not reachable. Resume it and try again.")
        sys.exit(1)

    # Narrow to a single config when --config is given; default to all.
    if args.config:
        selected = [c for c in CONFIGS if c.config_id == args.config]
        if not selected:
            print(f"Unknown config: {args.config}")
            print(f"Available: {[c.config_id for c in CONFIGS]}")
            sys.exit(1)
    else:
        selected = CONFIGS

    cases = load_medqa_cases(args.max_cases)
    if not cases:
        print("No MedQA cases loaded")
        sys.exit(1)
    print(f"Loaded {len(cases)} MedQA cases\n")

    RESULTS_DIR.mkdir(parents=True, exist_ok=True)

    banner = "=" * 60
    for cfg in selected:
        print(f"\n{banner}")
        print(f"  Running config: {cfg.config_id}")
        print(f"  {cfg.description}")
        print(f"  Max iterations: {cfg.max_iterations}, convergence: {cfg.convergence_threshold}")
        print(banner)

        summary = await run_config(cfg, cases)

        stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        fname = f"trackC_{cfg.config_id}_{stamp}.json"
        target = RESULTS_DIR / fname
        # save_results writes to its own results dir; relocate if different.
        written = save_results(summary, filename=fname)
        if written != target:
            import shutil
            shutil.move(str(written), str(target))

        if not args.quiet:
            print_summary(summary)


# Script entry point: run the whole async experiment in one event loop.
if __name__ == "__main__":
    asyncio.run(main())