""" MedQA dataset fetcher and validation harness. Downloads MedQA USMLE 4-option questions and evaluates the CDS pipeline's ability to arrive at the correct diagnosis / answer. Source: https://github.com/jind11/MedQA Format: JSONL with {question, options: {A, B, C, D}, answer_idx, answer} Metrics: - top1_accuracy: Correct answer matches #1 differential diagnosis - top3_accuracy: Correct answer in top 3 differential diagnoses - mentioned_accuracy: Correct answer mentioned anywhere in report - parse_success_rate: Pipeline completed without crashing """ from __future__ import annotations import asyncio import json import logging import random import re import time from pathlib import Path from typing import List, Optional import httpx from validation.base import ( DATA_DIR, ValidationCase, ValidationResult, ValidationSummary, clear_checkpoint, diagnosis_in_differential, ensure_data_dir, fuzzy_match, load_checkpoint, normalize_text, print_summary, run_cds_pipeline, save_incremental, save_results, score_case, ) from validation.question_classifier import ( classify_question, QuestionType, PIPELINE_APPROPRIATE_TYPES, ) from app.services.medgemma import MedGemmaService logger = logging.getLogger(__name__) # ────────────────────────────────────────────── # Data fetching # ────────────────────────────────────────────── # HuggingFace direct download (JSONL) MEDQA_JSONL_URL = "https://huggingface.co/datasets/GBaker/MedQA-USMLE-4-options/resolve/main/phrases_no_exclude_test.jsonl" async def fetch_medqa(max_cases: int = 50, seed: int = 42) -> List[ValidationCase]: """ Download MedQA test set and convert to ValidationCase objects. Args: max_cases: Maximum number of cases to sample seed: Random seed for reproducible sampling """ ensure_data_dir() cache_path = DATA_DIR / "medqa_test.jsonl" # Try to load from cache if cache_path.exists(): print(f" Loading MedQA from cache: {cache_path}") raw_cases = _load_jsonl(cache_path) else: print(f" Downloading MedQA test set...") raw_cases = await _download_medqa_jsonl(cache_path) if not raw_cases: raise RuntimeError("Failed to fetch MedQA data. Check network connection.") # Sample random.seed(seed) if len(raw_cases) > max_cases: raw_cases = random.sample(raw_cases, max_cases) # Convert to ValidationCase cases = [] for i, item in enumerate(raw_cases): question = item.get("question", "") options = item.get("options", item.get("answer_choices", {})) answer_idx = item.get("answer_idx", item.get("answer", "")) answer_text = item.get("answer", "") # Handle different formats if isinstance(options, dict): if answer_idx in options: answer_text = options[answer_idx] elif isinstance(options, list): # Some formats have options as a list idx = ord(answer_idx) - ord('A') if isinstance(answer_idx, str) and len(answer_idx) == 1 else 0 if idx < len(options): answer_text = options[idx] # Split question into vignette + stem (P3: preserve the question stem) vignette, question_stem = _split_question(question) case_obj = ValidationCase( case_id=f"medqa_{i:04d}", source_dataset="medqa", input_text=vignette, ground_truth={ "correct_answer": answer_text, "answer_idx": answer_idx, "options": options, "full_question": question, }, metadata={ "question_stem": question_stem, "clinical_vignette": vignette, "full_question_with_stem": question, }, ) # Classify question type (P1) case_obj.metadata["question_type"] = classify_question(case_obj).value cases.append(case_obj) print(f" Loaded {len(cases)} MedQA cases") return cases async def _download_medqa_jsonl(cache_path: Path) -> List[dict]: """Download MedQA JSONL from GitHub.""" async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client: try: r = await client.get(MEDQA_JSONL_URL) r.raise_for_status() lines = r.text.strip().split('\n') cases = [json.loads(line) for line in lines if line.strip()] # Cache cache_path.write_text('\n'.join(json.dumps(c) for c in cases)) print(f" Cached {len(cases)} MedQA cases to {cache_path}") return cases except Exception as e: print(f" Warning: Failed to download MedQA: {e}") return [] def _load_jsonl(path: Path) -> List[dict]: """Load JSONL file.""" cases = [] for line in path.read_text(encoding="utf-8").strip().split('\n'): if line.strip(): cases.append(json.loads(line)) return cases def _split_question(question: str) -> tuple: """ Split a USMLE question into (clinical_vignette, question_stem). Returns: (vignette, stem) where vignette is the clinical narrative and stem is the trailing question sentence (e.g. "Which of the following..."). If no stem is found, returns (full_question, ""). """ stems = [ r"which of the following", r"what is the most likely", r"what is the best next step", r"what is the most appropriate", r"what is the diagnosis", r"the most likely diagnosis is", r"this patient most likely has", r"what would be the next step", r"what is the next best", r"what is the underlying", r"what is the mechanism", r"which vitamin", r"which enzyme", r"which receptor", r"which drug", ] text = question.strip() for stem in stems: pattern = re.compile( r'\.?\s*([A-Z][^.]*?' + stem + r'[^.]*[\?\.]?)\s*$', re.IGNORECASE, ) match = pattern.search(text) if match: vignette = text[:match.start()].strip() q_stem = match.group(1).strip() if len(vignette) > 50: return vignette, q_stem # Fallback: no stem detected return text, "" # ────────────────────────────────────────────── # MCQ answer selection (P6) # ────────────────────────────────────────────── MCQ_SELECTION_PROMPT = """You are a medical expert answering a multiple-choice question. A clinical decision support system has produced the following analysis of a patient case: === CDS REPORT === {report_summary} === QUESTION === {full_question} === OPTIONS === {options_text} Based on the CDS analysis and your medical knowledge, select the single best answer. Respond with ONLY the letter (A, B, C, or D) on the first line, then a one-sentence justification. """ async def select_mcq_answer( case: ValidationCase, report, state=None, ) -> tuple: """ Use MedGemma to select an MCQ answer given CDS report context. Returns: (selected_letter, justification) e.g. ("B", "The patient's symptoms...") """ # Build report summary parts = [] if report.patient_summary: parts.append(f"Patient Summary: {report.patient_summary}") if report.differential_diagnosis: dx_list = ", ".join(d.diagnosis for d in report.differential_diagnosis[:5]) parts.append(f"Differential Diagnosis: {dx_list}") if report.suggested_next_steps: steps = ", ".join(a.action for a in report.suggested_next_steps[:5]) parts.append(f"Suggested Next Steps: {steps}") if report.guideline_recommendations: recs = ", ".join(report.guideline_recommendations[:3]) parts.append(f"Guideline Recommendations: {recs}") report_summary = "\n".join(parts) if parts else "No report available." # Build options text options = case.ground_truth.get("options", {}) if isinstance(options, dict): options_text = "\n".join(f"{k}. {v}" for k, v in sorted(options.items())) elif isinstance(options, list): options_text = "\n".join( f"{chr(65+i)}. {opt}" for i, opt in enumerate(options) ) else: options_text = str(options) full_question = case.ground_truth.get("full_question", case.input_text) prompt = MCQ_SELECTION_PROMPT.format( report_summary=report_summary, full_question=full_question, options_text=options_text, ) service = MedGemmaService() response = await service.generate( prompt=prompt, max_tokens=100, temperature=0.1, ) # Parse response: first line should be the letter lines = response.strip().split("\n") selected = lines[0].strip().rstrip(".").upper() # Extract just the letter if response is longer for char in selected: if char in "ABCDEFGH": selected = char break else: selected = "X" # Unparseable justification = " ".join(lines[1:]).strip() if len(lines) > 1 else "" return selected, justification # ────────────────────────────────────────────── # Validation harness # ────────────────────────────────────────────── async def validate_medqa( cases: List[ValidationCase], include_drug_check: bool = False, include_guidelines: bool = True, include_mcq: bool = True, delay_between_cases: float = 2.0, resume: bool = False, ) -> ValidationSummary: """ Run MedQA cases through the CDS pipeline and score results. Args: cases: List of MedQA ValidationCases include_drug_check: Whether to run drug interaction check (slower) include_guidelines: Whether to include guideline retrieval include_mcq: Whether to run MCQ answer selection step (adds 1 LLM call/case) delay_between_cases: Seconds to wait between cases (rate limiting) resume: If True, skip cases already in checkpoint and continue """ results: List[ValidationResult] = [] start_time = time.time() # Resume support: load completed cases from checkpoint completed_ids: set = set() if resume: prior = load_checkpoint("medqa") if prior: results.extend(prior) completed_ids = {r.case_id for r in prior} print(f" Resuming: {len(prior)} cases loaded from checkpoint, {len(cases) - len(completed_ids)} remaining") else: clear_checkpoint("medqa") for i, case in enumerate(cases): if case.case_id in completed_ids: print(f"\n [{i+1}/{len(cases)}] {case.case_id}: (cached) skipped") continue print(f"\n [{i+1}/{len(cases)}] {case.case_id}: ", end="", flush=True) case_start = time.monotonic() state, report, error = await run_cds_pipeline( patient_text=case.input_text, include_drug_check=include_drug_check, include_guidelines=include_guidelines, ) elapsed_ms = int((time.monotonic() - case_start) * 1000) # Build step results step_results = {} if state: step_results = {s.step_id: s.status.value for s in state.steps} # Score (type-aware: P4) scores = {} details = {} correct_answer = case.ground_truth["correct_answer"] question_type = case.metadata.get("question_type", "other") if report: # Type-aware scoring scores = score_case( target_answer=correct_answer, report=report, question_type=question_type, reasoning_result=state.clinical_reasoning if state else None, ) scores["parse_success"] = 1.0 # Extract non-float detail fields from scores dict match_location = scores.pop("match_location", "not_found") match_rank = scores.pop("match_rank", -1) # MCQ answer selection (P6) if include_mcq and case.ground_truth.get("options"): try: selected, justification = await select_mcq_answer(case, report, state) mcq_correct_idx = case.ground_truth.get("answer_idx", "") scores["mcq_accuracy"] = 1.0 if selected.upper() == mcq_correct_idx.upper() else 0.0 details["mcq_selected"] = selected details["mcq_justification"] = justification details["mcq_correct"] = mcq_correct_idx except Exception as e: logger.warning(f"MCQ selection failed for {case.case_id}: {e}") scores["mcq_accuracy"] = 0.0 details["mcq_error"] = str(e) # Rich details for debugging all_dx = [dx.diagnosis for dx in report.differential_diagnosis] all_next = [a.action for a in report.suggested_next_steps] all_recs = list(report.guideline_recommendations) details.update({ "correct_answer": correct_answer, "question_type": question_type, "top_diagnosis": all_dx[0] if all_dx else "NONE", "all_diagnoses": all_dx, "all_next_steps": all_next[:5], "all_recommendations": all_recs[:5], "num_diagnoses": len(report.differential_diagnosis), "match_location": match_location, "match_rank": match_rank, "patient_summary": report.patient_summary[:300] if report.patient_summary else "", }) # Console output mentioned = scores.get("mentioned_accuracy", 0.0) > 0 mcq_tag = "" if "mcq_accuracy" in scores: mcq_tag = f" mcq={'Y' if scores['mcq_accuracy'] > 0 else 'N'}" loc_tag = f"[{match_location}]" if mentioned else "" status_icon = "+" if mentioned else "-" print(f"{status_icon} [{question_type}] top1={'Y' if scores.get('top1_accuracy', 0) > 0 else 'N'} mentioned={'Y' if mentioned else 'N'}{mcq_tag} {loc_tag} ({elapsed_ms}ms)") else: scores = { "top1_accuracy": 0.0, "top3_accuracy": 0.0, "mentioned_accuracy": 0.0, "differential_accuracy": 0.0, "parse_success": 0.0, } details = { "correct_answer": correct_answer, "question_type": question_type, "error": error, "match_location": "not_found", } print(f"- FAILED: {error[:80] if error else 'unknown'}") result = ValidationResult( case_id=case.case_id, source_dataset="medqa", success=report is not None, scores=scores, pipeline_time_ms=elapsed_ms, step_results=step_results, report_summary=report.patient_summary[:200] if report else None, error=error, details=details, ) results.append(result) save_incremental(result, "medqa") # checkpoint after every case # Rate limit if i < len(cases) - 1: await asyncio.sleep(delay_between_cases) # Aggregate total = len(results) successful = sum(1 for r in results if r.success) # Overall metrics metric_names = [ "top1_accuracy", "top3_accuracy", "mentioned_accuracy", "differential_accuracy", "parse_success", "mcq_accuracy", ] metrics = {} for m in metric_names: values = [r.scores.get(m, 0.0) for r in results if m in r.scores] metrics[m] = sum(values) / len(values) if values else 0.0 # Average pipeline time times = [r.pipeline_time_ms for r in results if r.success] metrics["avg_pipeline_time_ms"] = sum(times) / len(times) if times else 0 # Stratified metrics by question type (P7) by_type: dict = {} for r in results: qt = r.details.get("question_type", "other") by_type.setdefault(qt, []).append(r) for qt, type_results in by_type.items(): n = len(type_results) metrics[f"count_{qt}"] = n for m in ["top1_accuracy", "top3_accuracy", "mentioned_accuracy", "mcq_accuracy"]: values = [r.scores.get(m, 0.0) for r in type_results if m in r.scores] if values: metrics[f"{m}_{qt}"] = sum(values) / len(values) # Pipeline-appropriate subset (diagnostic + treatment + lab_finding) appropriate_types = {t.value for t in PIPELINE_APPROPRIATE_TYPES} appropriate_results = [ r for r in results if r.details.get("question_type", "other") in appropriate_types ] if appropriate_results: for m in ["top1_accuracy", "top3_accuracy", "mentioned_accuracy"]: values = [r.scores.get(m, 0.0) for r in appropriate_results] metrics[f"{m}_pipeline_appropriate"] = sum(values) / len(values) if values else 0.0 metrics["count_pipeline_appropriate"] = len(appropriate_results) summary = ValidationSummary( dataset="medqa", total_cases=total, successful_cases=successful, failed_cases=total - successful, metrics=metrics, per_case=results, run_duration_sec=time.time() - start_time, ) return summary # ────────────────────────────────────────────── # Standalone runner # ────────────────────────────────────────────── async def main(): """Run MedQA validation standalone.""" import argparse parser = argparse.ArgumentParser(description="MedQA Validation") parser.add_argument("--max-cases", type=int, default=10, help="Number of cases to evaluate") parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--include-drugs", action="store_true", help="Include drug interaction check") parser.add_argument("--no-mcq", action="store_true", help="Disable MCQ answer selection step") parser.add_argument("--delay", type=float, default=2.0, help="Delay between cases (seconds)") args = parser.parse_args() print("MedQA Validation Harness") print("=" * 40) cases = await fetch_medqa(max_cases=args.max_cases, seed=args.seed) summary = await validate_medqa( cases, include_drug_check=args.include_drugs, include_mcq=not args.no_mcq, delay_between_cases=args.delay, ) print_summary(summary) path = save_results(summary) print(f"Results saved to: {path}") if __name__ == "__main__": asyncio.run(main())