bshepp committed on
Commit
393ff7f
·
1 Parent(s): c28dd56

Add validation framework for MedQA, MTSamples, and PMC Case Reports

Browse files

Three-dataset validation suite that evaluates the CDS pipeline against
external clinical benchmarks:

- MedQA: 1273 USMLE-style questions, measures diagnostic accuracy
(top-1, top-3, mentioned-anywhere)
- MTSamples: ~5000 medical transcriptions, measures parse robustness
and field extraction completeness across specialties
- PMC Case Reports: Published case reports from PubMed, measures
real-world diagnostic accuracy against gold-standard diagnoses

Architecture:
- validation/base.py: Core framework (runner, fuzzy matching, scoring)
- validation/harness_medqa.py: MedQA fetcher + harness
- validation/harness_mtsamples.py: MTSamples fetcher + harness
- validation/harness_pmc.py: PMC Case Reports fetcher + harness
- validation/run_validation.py: Unified CLI runner

Uses direct Orchestrator calls (no server needed). Tested end-to-end:
3 MedQA cases, 66.7% top-1 accuracy, 100% parse success rate.

.gitignore CHANGED
@@ -39,6 +39,10 @@ out/
39
  # Test outputs
40
  results.json
41
 
 
 
 
 
42
  # Models (too large for git)
43
  models/*.bin
44
  models/*.pt
 
39
  # Test outputs
40
  results.json
41
 
42
+ # Validation datasets (downloaded on demand) and results
43
+ src/backend/validation/data/
44
+ src/backend/validation/results/
45
+
46
  # Models (too large for git)
47
  models/*.bin
48
  models/*.pt
src/backend/validation/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Validation framework for the Clinical Decision Support Agent.
3
+
4
+ Validates the CDS pipeline against three external clinical datasets:
5
+ - MedQA (USMLE-style questions) β€” diagnostic accuracy
6
+ - MTSamples (medical transcriptions) β€” parse robustness
7
+ - PMC Case Reports (published cases) β€” real-world diagnostic accuracy
8
+ """
src/backend/validation/base.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Base classes and utilities for the validation framework.
3
+
4
+ Provides:
5
+ - ValidationCase: a single test case with input + ground truth
6
+ - ValidationResult: scored result for a single case
7
+ - ValidationSummary: aggregate metrics for a dataset
8
+ - run_cds_pipeline(): runs a case through the orchestrator directly
9
+ - fuzzy_match(): soft string matching for diagnosis comparison
10
+ """
11
+ from __future__ import annotations
12
+
13
+ import asyncio
14
+ import json
15
+ import re
16
+ import time
17
+ from dataclasses import dataclass, field
18
+ from datetime import datetime, timezone
19
+ from pathlib import Path
20
+ from typing import Any, Dict, List, Optional
21
+
22
+ # ── CDS pipeline imports ──
23
+ import sys
24
+
25
+ # Ensure the backend app is importable
26
+ BACKEND_DIR = Path(__file__).resolve().parent.parent
27
+ if str(BACKEND_DIR) not in sys.path:
28
+ sys.path.insert(0, str(BACKEND_DIR))
29
+
30
+ from app.agent.orchestrator import Orchestrator
31
+ from app.models.schemas import CaseSubmission, CDSReport, AgentState
32
+
33
+
34
+ # ──────────────────────────────────────────────
35
+ # Data classes
36
+ # ──────────────────────────────────────────────
37
+
38
+ @dataclass
39
+ class ValidationCase:
40
+ """A single validation test case."""
41
+ case_id: str
42
+ source_dataset: str # "medqa", "mtsamples", "pmc"
43
+ input_text: str # Clinical text fed to the pipeline
44
+ ground_truth: Dict[str, Any] # Dataset-specific ground truth
45
+ metadata: Dict[str, Any] = field(default_factory=dict)
46
+
47
+
48
+ @dataclass
49
+ class ValidationResult:
50
+ """Result of running one case through the pipeline + scoring."""
51
+ case_id: str
52
+ source_dataset: str
53
+ success: bool # Pipeline completed without crash
54
+ scores: Dict[str, float] # Metric name β†’ score (0.0–1.0)
55
+ pipeline_time_ms: int = 0
56
+ step_results: Dict[str, str] = field(default_factory=dict) # step_id β†’ status
57
+ report_summary: Optional[str] = None
58
+ error: Optional[str] = None
59
+ details: Dict[str, Any] = field(default_factory=dict) # Extra scoring info
60
+
61
+
62
+ @dataclass
63
+ class ValidationSummary:
64
+ """Aggregate metrics for a dataset validation run."""
65
+ dataset: str
66
+ total_cases: int
67
+ successful_cases: int
68
+ failed_cases: int
69
+ metrics: Dict[str, float] # Metric name β†’ average score
70
+ per_case: List[ValidationResult]
71
+ run_duration_sec: float
72
+ timestamp: str = ""
73
+
74
+ def __post_init__(self):
75
+ if not self.timestamp:
76
+ self.timestamp = datetime.now(timezone.utc).isoformat()
77
+
78
+
79
+ # ──────────────────────────────────────────────
80
+ # Pipeline runner
81
+ # ──────────────────────────────────────────────
82
+
83
async def run_cds_pipeline(
    patient_text: str,
    include_drug_check: bool = True,
    include_guidelines: bool = True,
    timeout_sec: int = 180,
) -> tuple[Optional[AgentState], Optional[CDSReport], Optional[str]]:
    """
    Run a single case through the CDS pipeline directly (no HTTP server needed).

    Args:
        patient_text: Raw clinical text submitted as the case.
        include_drug_check: Whether the orchestrator runs the drug interaction step.
        include_guidelines: Whether the orchestrator retrieves guidelines.
        timeout_sec: Hard wall-clock limit for the whole pipeline run.

    Returns:
        (state, report, error) -- error is None on success.
    """
    case = CaseSubmission(
        patient_text=patient_text,
        include_drug_check=include_drug_check,
        include_guidelines=include_guidelines,
    )
    orchestrator = Orchestrator()

    async def _consume() -> None:
        # Drain the step-update stream; we only care about the final state.
        async for _step_update in orchestrator.run(case):
            pass

    try:
        # BUG FIX: timeout_sec was accepted but never applied, so the
        # asyncio.TimeoutError branch below was unreachable. wait_for makes
        # the timeout real and cancels the pipeline when it fires.
        await asyncio.wait_for(_consume(), timeout=timeout_sec)
        return orchestrator.state, orchestrator.get_result(), None
    except asyncio.TimeoutError:
        return orchestrator.state, None, f"Pipeline timed out after {timeout_sec}s"
    except Exception as e:
        return orchestrator.state, None, str(e)
111
+
112
+
113
+ # ──────────────────────────────────────────────
114
+ # Fuzzy string matching for diagnosis comparison
115
+ # ──────────────────────────────────────────────
116
+
117
def normalize_text(text: str) -> str:
    """Lowercase *text*, replace punctuation with spaces, and collapse whitespace.

    BUG FIX: the original stripped *before* removing punctuation, so input
    like "abc!" normalized to "abc " (trailing space), which broke the exact
    substring comparisons done by fuzzy_match(). Stripping last fixes that.
    """
    text = text.lower()
    text = re.sub(r'[^\w\s]', ' ', text)  # punctuation -> space
    text = re.sub(r'\s+', ' ', text)      # collapse runs of whitespace
    return text.strip()
123
+
124
+
125
def fuzzy_match(candidate: str, target: str, threshold: float = 0.6) -> bool:
    """
    Decide whether *candidate* is a soft match for *target*.

    Matching is token-overlap based (Jaccard-like over the smaller token set)
    rather than edit distance -- medical terms are long and we care about
    semantic overlap, not typos.

    Args:
        candidate: Text from the pipeline output.
        target: Ground truth text.
        threshold: Minimum token overlap ratio (0.0-1.0).
    """
    norm_candidate = normalize_text(candidate)
    norm_target = normalize_text(target)

    target_tokens = set(norm_target.split())
    if not target_tokens:
        return False

    # Containment in either direction counts as a full match.
    if norm_target in norm_candidate or norm_candidate in norm_target:
        return True

    # Otherwise score by token overlap relative to the smaller set.
    candidate_tokens = set(norm_candidate.split())
    smaller = min(len(candidate_tokens), len(target_tokens))
    if smaller == 0:
        return False

    shared = len(candidate_tokens & target_tokens)
    return shared / smaller >= threshold
156
+
157
+
158
def diagnosis_in_differential(
    target_diagnosis: str,
    report: CDSReport,
    top_n: Optional[int] = None,
) -> tuple[bool, int]:
    """
    Look for *target_diagnosis* in the report's differential.

    Args:
        target_diagnosis: Ground-truth diagnosis string.
        report: Generated CDS report.
        top_n: When truthy, only the first N differential entries count as
            ranked hits.

    Returns:
        (found, rank) -- rank is the 0-indexed position in the differential,
        len(differential) when only found in the free-text body, or -1 when
        absent entirely.
    """
    candidates = report.differential_diagnosis
    if top_n:
        candidates = candidates[:top_n]

    for rank, entry in enumerate(candidates):
        if fuzzy_match(entry.diagnosis, target_diagnosis):
            return True, rank

    # Fall back to scanning the narrative sections of the report
    # (patient_summary, guideline_recommendations, next steps).
    narrative = " ".join([
        report.patient_summary or "",
        " ".join(report.guideline_recommendations),
        " ".join(step.action for step in report.suggested_next_steps),
    ])
    if fuzzy_match(narrative, target_diagnosis, threshold=0.3):
        # Mentioned somewhere, but not as a ranked differential entry.
        return True, len(candidates)

    return False, -1
187
+
188
+
189
+ # ──────────────────────────────────────────────
190
+ # I/O utilities
191
+ # ──────────────────────────────────────────────
192
+
193
+ DATA_DIR = Path(__file__).resolve().parent / "data"
194
+
195
+
196
+ def ensure_data_dir():
197
+ """Create the data directory if it doesn't exist."""
198
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
199
+
200
+
201
def save_results(summary: ValidationSummary, filename: Optional[str] = None) -> Path:
    """
    Serialize a ValidationSummary to JSON under validation/results/.

    Args:
        summary: Aggregate run results to persist.
        filename: Target file name; defaults to "<dataset>_<UTC timestamp>.json".

    Returns:
        Path of the written JSON file.
    """
    results_dir = Path(__file__).resolve().parent / "results"
    results_dir.mkdir(parents=True, exist_ok=True)

    if filename is None:
        ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        filename = f"{summary.dataset}_{ts}.json"

    path = results_dir / filename

    # Flatten to plain dicts; per_case holds ValidationResult dataclasses.
    data = {
        "dataset": summary.dataset,
        "total_cases": summary.total_cases,
        "successful_cases": summary.successful_cases,
        "failed_cases": summary.failed_cases,
        "metrics": summary.metrics,
        "run_duration_sec": summary.run_duration_sec,
        "timestamp": summary.timestamp,
        "per_case": [
            {
                "case_id": r.case_id,
                "success": r.success,
                "scores": r.scores,
                "pipeline_time_ms": r.pipeline_time_ms,
                "step_results": r.step_results,
                "report_summary": r.report_summary,
                "error": r.error,
                "details": r.details,
            }
            for r in summary.per_case
        ],
    }

    # default=str: last-resort stringification for anything non-JSON-native
    # that ends up in `details` (e.g. Paths or datetimes).
    path.write_text(json.dumps(data, indent=2, default=str))
    return path
238
+
239
+
240
def print_summary(summary: ValidationSummary):
    """Pretty-print a ValidationSummary to the console."""
    bar = '=' * 60
    print(f"\n{bar}")
    print(f" Validation Results: {summary.dataset.upper()}")
    print(bar)
    print(f" Total cases: {summary.total_cases}")
    print(f" Successful: {summary.successful_cases}")
    print(f" Failed: {summary.failed_cases}")
    print(f" Duration: {summary.run_duration_sec:.1f}s")
    print("\n Metrics:")
    for name, value in sorted(summary.metrics.items()):
        # Timing metrics render as whole milliseconds; other floats as percentages.
        if "time" in name and isinstance(value, (int, float)):
            rendered = f"{value:.0f}ms"
        elif isinstance(value, float):
            rendered = f"{value:.1%}"
        else:
            rendered = str(value)
        print(f" {name:30s} {rendered}")
    print(f"{bar}\n")
src/backend/validation/harness_medqa.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MedQA dataset fetcher and validation harness.
3
+
4
+ Downloads MedQA USMLE 4-option questions and evaluates the CDS pipeline's
5
+ ability to arrive at the correct diagnosis / answer.
6
+
7
+ Source: https://github.com/jind11/MedQA
8
+ Format: JSONL with {question, options: {A, B, C, D}, answer_idx, answer}
9
+
10
+ Metrics:
11
+ - top1_accuracy: Correct answer matches #1 differential diagnosis
12
+ - top3_accuracy: Correct answer in top 3 differential diagnoses
13
+ - mentioned_accuracy: Correct answer mentioned anywhere in report
14
+ - parse_success_rate: Pipeline completed without crashing
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import json
20
+ import random
21
+ import time
22
+ from pathlib import Path
23
+ from typing import List, Optional
24
+
25
+ import httpx
26
+
27
+ from validation.base import (
28
+ DATA_DIR,
29
+ ValidationCase,
30
+ ValidationResult,
31
+ ValidationSummary,
32
+ diagnosis_in_differential,
33
+ ensure_data_dir,
34
+ fuzzy_match,
35
+ normalize_text,
36
+ print_summary,
37
+ run_cds_pipeline,
38
+ save_results,
39
+ )
40
+
41
+
42
+ # ──────────────────────────────────────────────
43
+ # Data fetching
44
+ # ──────────────────────────────────────────────
45
+
46
+ # HuggingFace direct download (JSONL)
47
+ MEDQA_JSONL_URL = "https://huggingface.co/datasets/GBaker/MedQA-USMLE-4-options/resolve/main/phrases_no_exclude_test.jsonl"
48
+
49
+
50
+ async def fetch_medqa(max_cases: int = 50, seed: int = 42) -> List[ValidationCase]:
51
+ """
52
+ Download MedQA test set and convert to ValidationCase objects.
53
+
54
+ Args:
55
+ max_cases: Maximum number of cases to sample
56
+ seed: Random seed for reproducible sampling
57
+ """
58
+ ensure_data_dir()
59
+ cache_path = DATA_DIR / "medqa_test.jsonl"
60
+
61
+ # Try to load from cache
62
+ if cache_path.exists():
63
+ print(f" Loading MedQA from cache: {cache_path}")
64
+ raw_cases = _load_jsonl(cache_path)
65
+ else:
66
+ print(f" Downloading MedQA test set...")
67
+ raw_cases = await _download_medqa_jsonl(cache_path)
68
+
69
+ if not raw_cases:
70
+ raise RuntimeError("Failed to fetch MedQA data. Check network connection.")
71
+
72
+ # Sample
73
+ random.seed(seed)
74
+ if len(raw_cases) > max_cases:
75
+ raw_cases = random.sample(raw_cases, max_cases)
76
+
77
+ # Convert to ValidationCase
78
+ cases = []
79
+ for i, item in enumerate(raw_cases):
80
+ question = item.get("question", "")
81
+ options = item.get("options", item.get("answer_choices", {}))
82
+ answer_idx = item.get("answer_idx", item.get("answer", ""))
83
+ answer_text = item.get("answer", "")
84
+
85
+ # Handle different formats
86
+ if isinstance(options, dict):
87
+ if answer_idx in options:
88
+ answer_text = options[answer_idx]
89
+ elif isinstance(options, list):
90
+ # Some formats have options as a list
91
+ idx = ord(answer_idx) - ord('A') if isinstance(answer_idx, str) and len(answer_idx) == 1 else 0
92
+ if idx < len(options):
93
+ answer_text = options[idx]
94
+
95
+ # Build clinical vignette (question only, not the options)
96
+ # This simulates what a clinician would present
97
+ clinical_text = _extract_vignette(question)
98
+
99
+ cases.append(ValidationCase(
100
+ case_id=f"medqa_{i:04d}",
101
+ source_dataset="medqa",
102
+ input_text=clinical_text,
103
+ ground_truth={
104
+ "correct_answer": answer_text,
105
+ "answer_idx": answer_idx,
106
+ "options": options,
107
+ "full_question": question,
108
+ },
109
+ ))
110
+
111
+ print(f" Loaded {len(cases)} MedQA cases")
112
+ return cases
113
+
114
+
115
async def _download_medqa_jsonl(cache_path: Path) -> List[dict]:
    """Fetch the MedQA test JSONL from its HuggingFace mirror and cache it.

    Returns the parsed records, or an empty list when the download fails
    (the caller decides whether that is fatal).
    """
    async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
        try:
            response = await client.get(MEDQA_JSONL_URL)
            response.raise_for_status()

            records = [
                json.loads(line)
                for line in response.text.strip().split('\n')
                if line.strip()
            ]

            # Persist a normalized copy for subsequent runs.
            cache_path.write_text('\n'.join(json.dumps(rec) for rec in records))
            print(f" Cached {len(records)} MedQA cases to {cache_path}")
            return records

        except Exception as e:
            print(f" Warning: Failed to download MedQA: {e}")
            return []
133
+
134
+
135
+ def _load_jsonl(path: Path) -> List[dict]:
136
+ """Load JSONL file."""
137
+ cases = []
138
+ for line in path.read_text(encoding="utf-8").strip().split('\n'):
139
+ if line.strip():
140
+ cases.append(json.loads(line))
141
+ return cases
142
+
143
+
144
+ def _extract_vignette(question: str) -> str:
145
+ """
146
+ Extract the clinical vignette from a USMLE question.
147
+
148
+ USMLE questions typically end with "Which of the following..." or
149
+ "What is the most likely diagnosis?". We strip the question stem
150
+ to leave just the clinical narrative.
151
+ """
152
+ # Common question stems
153
+ stems = [
154
+ r"which of the following",
155
+ r"what is the most likely",
156
+ r"what is the best next step",
157
+ r"what is the most appropriate",
158
+ r"what is the diagnosis",
159
+ r"the most likely diagnosis is",
160
+ r"this patient most likely has",
161
+ r"what would be the next step",
162
+ ]
163
+
164
+ text = question.strip()
165
+ for stem in stems:
166
+ import re
167
+ # Find the last sentence that starts a question
168
+ pattern = re.compile(rf'\.?\s*[A-Z].*{stem}.*[\?\.]?\s*$', re.IGNORECASE)
169
+ match = pattern.search(text)
170
+ if match:
171
+ # Return everything before the question stem sentence
172
+ vignette = text[:match.start()].strip()
173
+ if len(vignette) > 50: # Sanity check
174
+ return vignette
175
+
176
+ # Fallback: return the full text
177
+ return text
178
+
179
+
180
+ # ──────────────────────────────────────────────
181
+ # Validation harness
182
+ # ──────────────────────────────────────────────
183
+
184
+ async def validate_medqa(
185
+ cases: List[ValidationCase],
186
+ include_drug_check: bool = False,
187
+ include_guidelines: bool = True,
188
+ delay_between_cases: float = 2.0,
189
+ ) -> ValidationSummary:
190
+ """
191
+ Run MedQA cases through the CDS pipeline and score results.
192
+
193
+ Args:
194
+ cases: List of MedQA ValidationCases
195
+ include_drug_check: Whether to run drug interaction check (slower)
196
+ include_guidelines: Whether to include guideline retrieval
197
+ delay_between_cases: Seconds to wait between cases (rate limiting)
198
+ """
199
+ results: List[ValidationResult] = []
200
+ start_time = time.time()
201
+
202
+ for i, case in enumerate(cases):
203
+ print(f"\n [{i+1}/{len(cases)}] {case.case_id}: ", end="", flush=True)
204
+
205
+ case_start = time.monotonic()
206
+
207
+ state, report, error = await run_cds_pipeline(
208
+ patient_text=case.input_text,
209
+ include_drug_check=include_drug_check,
210
+ include_guidelines=include_guidelines,
211
+ )
212
+
213
+ elapsed_ms = int((time.monotonic() - case_start) * 1000)
214
+
215
+ # Build step results
216
+ step_results = {}
217
+ if state:
218
+ step_results = {s.step_id: s.status.value for s in state.steps}
219
+
220
+ # Score
221
+ scores = {}
222
+ details = {}
223
+ correct_answer = case.ground_truth["correct_answer"]
224
+
225
+ if report:
226
+ # Top-1 accuracy
227
+ found_top1, rank = diagnosis_in_differential(correct_answer, report, top_n=1)
228
+ scores["top1_accuracy"] = 1.0 if found_top1 else 0.0
229
+
230
+ # Top-3 accuracy
231
+ found_top3, rank3 = diagnosis_in_differential(correct_answer, report, top_n=3)
232
+ scores["top3_accuracy"] = 1.0 if found_top3 else 0.0
233
+
234
+ # Mentioned anywhere
235
+ found_any, rank_any = diagnosis_in_differential(correct_answer, report)
236
+ scores["mentioned_accuracy"] = 1.0 if found_any else 0.0
237
+
238
+ # Parse success
239
+ scores["parse_success"] = 1.0
240
+
241
+ details = {
242
+ "correct_answer": correct_answer,
243
+ "top_diagnosis": report.differential_diagnosis[0].diagnosis if report.differential_diagnosis else "NONE",
244
+ "num_diagnoses": len(report.differential_diagnosis),
245
+ "found_at_rank": rank_any if found_any else -1,
246
+ }
247
+
248
+ status_icon = "βœ“" if found_top3 else "βœ—"
249
+ print(f"{status_icon} top1={'Y' if found_top1 else 'N'} top3={'Y' if found_top3 else 'N'} ({elapsed_ms}ms)")
250
+ else:
251
+ scores = {
252
+ "top1_accuracy": 0.0,
253
+ "top3_accuracy": 0.0,
254
+ "mentioned_accuracy": 0.0,
255
+ "parse_success": 0.0,
256
+ }
257
+ details = {"correct_answer": correct_answer, "error": error}
258
+ print(f"βœ— FAILED: {error[:80] if error else 'unknown'}")
259
+
260
+ results.append(ValidationResult(
261
+ case_id=case.case_id,
262
+ source_dataset="medqa",
263
+ success=report is not None,
264
+ scores=scores,
265
+ pipeline_time_ms=elapsed_ms,
266
+ step_results=step_results,
267
+ report_summary=report.patient_summary[:200] if report else None,
268
+ error=error,
269
+ details=details,
270
+ ))
271
+
272
+ # Rate limit
273
+ if i < len(cases) - 1:
274
+ await asyncio.sleep(delay_between_cases)
275
+
276
+ # Aggregate
277
+ total = len(results)
278
+ successful = sum(1 for r in results if r.success)
279
+
280
+ # Average each metric across successful cases only
281
+ metric_names = ["top1_accuracy", "top3_accuracy", "mentioned_accuracy", "parse_success"]
282
+ metrics = {}
283
+ for m in metric_names:
284
+ values = [r.scores.get(m, 0.0) for r in results]
285
+ metrics[m] = sum(values) / len(values) if values else 0.0
286
+
287
+ # Average pipeline time
288
+ times = [r.pipeline_time_ms for r in results if r.success]
289
+ metrics["avg_pipeline_time_ms"] = sum(times) / len(times) if times else 0
290
+
291
+ summary = ValidationSummary(
292
+ dataset="medqa",
293
+ total_cases=total,
294
+ successful_cases=successful,
295
+ failed_cases=total - successful,
296
+ metrics=metrics,
297
+ per_case=results,
298
+ run_duration_sec=time.time() - start_time,
299
+ )
300
+
301
+ return summary
302
+
303
+
304
+ # ──────────────────────────────────────────────
305
+ # Standalone runner
306
+ # ──────────────────────────────────────────────
307
+
308
async def main():
    """CLI entry point for running the MedQA harness standalone."""
    import argparse

    arg_parser = argparse.ArgumentParser(description="MedQA Validation")
    arg_parser.add_argument("--max-cases", type=int, default=10, help="Number of cases to evaluate")
    arg_parser.add_argument("--seed", type=int, default=42, help="Random seed")
    arg_parser.add_argument("--include-drugs", action="store_true", help="Include drug interaction check")
    arg_parser.add_argument("--delay", type=float, default=2.0, help="Delay between cases (seconds)")
    opts = arg_parser.parse_args()

    print("MedQA Validation Harness")
    print("=" * 40)

    # Fetch (or load cached) cases, run them, then report and persist.
    medqa_cases = await fetch_medqa(max_cases=opts.max_cases, seed=opts.seed)
    summary = await validate_medqa(
        medqa_cases,
        include_drug_check=opts.include_drugs,
        delay_between_cases=opts.delay,
    )

    print_summary(summary)
    out_path = save_results(summary)
    print(f"Results saved to: {out_path}")
332
+
333
+
334
+ if __name__ == "__main__":
335
+ asyncio.run(main())
src/backend/validation/harness_mtsamples.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MTSamples dataset fetcher and validation harness.
3
+
4
+ Downloads medical transcription samples and evaluates the CDS pipeline's
5
+ ability to parse diverse clinical note formats and reason across specialties.
6
+
7
+ Source: https://mtsamples.com (via GitHub mirrors)
8
+ Format: CSV with columns: description, medical_specialty, sample_name, transcription, keywords
9
+
10
+ Metrics:
11
+ - parse_success_rate: Pipeline completed without crashing
12
+ - field_completeness: How many structured fields were extracted
13
+ - specialty_alignment: System reasoning aligns with correct specialty
14
+ - has_differential: Report includes at least one diagnosis
15
+ - has_recommendations: Report includes next steps
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import asyncio
20
+ import csv
21
+ import io
22
+ import json
23
+ import random
24
+ import re
25
+ import time
26
+ from pathlib import Path
27
+ from typing import List, Optional
28
+
29
+ import httpx
30
+
31
+ from validation.base import (
32
+ DATA_DIR,
33
+ ValidationCase,
34
+ ValidationResult,
35
+ ValidationSummary,
36
+ ensure_data_dir,
37
+ fuzzy_match,
38
+ normalize_text,
39
+ print_summary,
40
+ run_cds_pipeline,
41
+ save_results,
42
+ )
43
+
44
+
45
+ # ──────────────────────────────────────────────
46
+ # Data fetching
47
+ # ──────────────────────────────────────────────
48
+
49
+ MTSAMPLES_URL = "https://raw.githubusercontent.com/socd06/medical-nlp/master/data/mtsamples.csv"
50
+ MTSAMPLES_FALLBACK_URL = "https://raw.githubusercontent.com/Abonia1/Clinical-NLP-on-MTSamples/master/mtsamples.csv"
51
+
52
+ # Specialties most relevant to CDS
53
+ RELEVANT_SPECIALTIES = {
54
+ "Cardiovascular / Pulmonary",
55
+ "Gastroenterology",
56
+ "General Medicine",
57
+ "Neurology",
58
+ "Orthopedic",
59
+ "Urology",
60
+ "Nephrology",
61
+ "Endocrinology",
62
+ "Hematology - Oncology",
63
+ "Obstetrics / Gynecology",
64
+ "Emergency Room Reports",
65
+ "Consult - History and Phy.",
66
+ "Discharge Summary",
67
+ "SOAP / Chart / Progress Notes",
68
+ "Internal Medicine",
69
+ }
70
+
71
+
72
+ async def fetch_mtsamples(
73
+ max_cases: int = 30,
74
+ seed: int = 42,
75
+ specialties: Optional[set] = None,
76
+ min_length: int = 200,
77
+ ) -> List[ValidationCase]:
78
+ """
79
+ Download MTSamples and convert to ValidationCase objects.
80
+
81
+ Args:
82
+ max_cases: Maximum number of cases to sample
83
+ seed: Random seed for reproducible sampling
84
+ specialties: Filter to these specialties (None = use RELEVANT_SPECIALTIES)
85
+ min_length: Minimum transcription length to include
86
+ """
87
+ ensure_data_dir()
88
+ cache_path = DATA_DIR / "mtsamples.csv"
89
+
90
+ if cache_path.exists():
91
+ print(f" Loading MTSamples from cache: {cache_path}")
92
+ raw_text = cache_path.read_text(encoding="utf-8")
93
+ else:
94
+ print(f" Downloading MTSamples...")
95
+ raw_text = await _download_mtsamples(cache_path)
96
+
97
+ if not raw_text:
98
+ raise RuntimeError("Failed to fetch MTSamples data.")
99
+
100
+ # Parse CSV
101
+ reader = csv.DictReader(io.StringIO(raw_text))
102
+ rows = list(reader)
103
+
104
+ # Filter
105
+ target_specialties = specialties or RELEVANT_SPECIALTIES
106
+ filtered = []
107
+ for row in rows:
108
+ specialty = row.get("medical_specialty", "").strip()
109
+ transcription = row.get("transcription", "").strip()
110
+ if not transcription or len(transcription) < min_length:
111
+ continue
112
+ if specialty in target_specialties:
113
+ filtered.append(row)
114
+
115
+ # Sample
116
+ random.seed(seed)
117
+ if len(filtered) > max_cases:
118
+ # Stratified sample: try to get cases from diverse specialties
119
+ by_specialty = {}
120
+ for row in filtered:
121
+ sp = row.get("medical_specialty", "Other")
122
+ by_specialty.setdefault(sp, []).append(row)
123
+
124
+ sampled = []
125
+ per_specialty = max(1, max_cases // len(by_specialty))
126
+ for sp, sp_rows in by_specialty.items():
127
+ sampled.extend(random.sample(sp_rows, min(per_specialty, len(sp_rows))))
128
+
129
+ # Fill remaining slots randomly
130
+ remaining = [r for r in filtered if r not in sampled]
131
+ if len(sampled) < max_cases and remaining:
132
+ sampled.extend(random.sample(remaining, min(max_cases - len(sampled), len(remaining))))
133
+
134
+ filtered = sampled[:max_cases]
135
+
136
+ # Convert to ValidationCase
137
+ cases = []
138
+ for i, row in enumerate(filtered):
139
+ transcription = row.get("transcription", "").strip()
140
+ specialty = row.get("medical_specialty", "Unknown").strip()
141
+ description = row.get("description", "").strip()
142
+ keywords = row.get("keywords", "").strip()
143
+
144
+ cases.append(ValidationCase(
145
+ case_id=f"mts_{i:04d}",
146
+ source_dataset="mtsamples",
147
+ input_text=transcription,
148
+ ground_truth={
149
+ "specialty": specialty,
150
+ "description": description,
151
+ "keywords": keywords,
152
+ },
153
+ metadata={
154
+ "sample_name": row.get("sample_name", ""),
155
+ "text_length": len(transcription),
156
+ },
157
+ ))
158
+
159
+ print(f" Loaded {len(cases)} MTSamples cases across {len(set(c.ground_truth['specialty'] for c in cases))} specialties")
160
+ return cases
161
+
162
+
163
async def _download_mtsamples(cache_path: Path) -> str:
    """Download the MTSamples CSV, trying each mirror in turn.

    Caches the raw CSV text at *cache_path*. Returns "" when every mirror
    fails (the caller treats that as fatal).
    """
    async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
        for url in (MTSAMPLES_URL, MTSAMPLES_FALLBACK_URL):
            try:
                response = await client.get(url)
                response.raise_for_status()
            except Exception as e:
                print(f" Warning: Failed to download from {url}: {e}")
                continue
            cache_path.write_text(response.text, encoding="utf-8")
            print(f" Cached MTSamples ({len(response.text)} bytes) to {cache_path}")
            return response.text
    return ""
177
+
178
+
179
+ # ──────────────────────────────────────────────
180
+ # Scoring helpers
181
+ # ──────────────────────────────────────────────
182
+
183
+ SPECIALTY_KEYWORDS = {
184
+ "Cardiovascular / Pulmonary": ["cardiac", "heart", "coronary", "pulmonary", "lung", "chest", "hypertension", "arrhythmia"],
185
+ "Gastroenterology": ["gastro", "liver", "hepat", "colon", "bowel", "gi ", "abdominal", "pancrea"],
186
+ "General Medicine": ["general", "medicine", "primary", "routine"],
187
+ "Neurology": ["neuro", "brain", "seizure", "stroke", "headache", "neuropathy", "ms "],
188
+ "Orthopedic": ["ortho", "fracture", "bone", "joint", "knee", "hip", "shoulder", "spine"],
189
+ "Urology": ["urol", "kidney", "bladder", "prostate", "renal", "urinary"],
190
+ "Nephrology": ["renal", "kidney", "dialysis", "nephr", "creatinine"],
191
+ "Endocrinology": ["diabet", "thyroid", "endocrin", "insulin", "glucose", "adrenal"],
192
+ "Hematology - Oncology": ["cancer", "tumor", "leukemia", "lymphoma", "anemia", "oncol"],
193
+ "Obstetrics / Gynecology": ["pregnan", "obstet", "gynecol", "uterus", "ovarian", "menstrual"],
194
+ "Emergency Room Reports": ["emergency", "trauma", "acute", "er ", "ed "],
195
+ "Internal Medicine": ["internal", "medicine"],
196
+ }
197
+
198
+
199
def check_specialty_alignment(report_text: str, target_specialty: str) -> bool:
    """Return True when the report text contains at least one keyword
    associated with *target_specialty*."""
    expected = SPECIALTY_KEYWORDS.get(target_specialty, [])
    if not expected:
        # No keyword list for this specialty: nothing to check, assume aligned.
        return True

    lowered = report_text.lower()
    return any(keyword in lowered for keyword in expected)
208
+
209
+
210
def score_field_completeness(state) -> float:
    """Fraction (0.0-1.0) of structured PatientProfile fields populated by parsing."""
    if not state or not state.patient_profile:
        return 0.0

    p = state.patient_profile
    # One boolean per tracked field; the score is the populated fraction.
    checks = (
        p.age is not None,
        p.gender.value != "unknown",
        bool(p.chief_complaint),
        bool(p.history_of_present_illness),
        len(p.past_medical_history) > 0,
        len(p.current_medications) > 0,
        len(p.allergies) > 0,
        len(p.lab_results) > 0,
        p.vital_signs is not None,
        bool(p.social_history),
        bool(p.family_history),
    )
    return sum(checks) / len(checks)
230
+
231
+
232
+ # ──────────────────────────────────────────────
233
+ # Validation harness
234
+ # ──────────────────────────────────────────────
235
+
236
async def validate_mtsamples(
    cases: List[ValidationCase],
    include_drug_check: bool = True,
    include_guidelines: bool = True,
    delay_between_cases: float = 2.0,
) -> ValidationSummary:
    """
    Run MTSamples cases through the CDS pipeline and score results.

    Args:
        cases: MTSamples transcriptions wrapped as ValidationCase objects;
            ground_truth carries the source "specialty" label.
        include_drug_check: Forwarded to the pipeline; enables the drug
            interaction step.
        include_guidelines: Forwarded to the pipeline; enables guideline
            retrieval.
        delay_between_cases: Sleep (seconds) between cases — rate-limits
            calls to upstream services used by the pipeline.

    Returns:
        ValidationSummary with per-case results and dataset-level averages
        for parse robustness metrics (parse_success, field_completeness,
        specialty_alignment, ...).
    """
    results: List[ValidationResult] = []
    start_time = time.time()

    for i, case in enumerate(cases):
        specialty = case.ground_truth.get("specialty", "?")
        # Progress line; end="" so the per-case verdict is appended later.
        print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty}): ", end="", flush=True)

        # monotonic() for per-case wall time (immune to clock adjustments).
        case_start = time.monotonic()

        state, report, error = await run_cds_pipeline(
            patient_text=case.input_text,
            include_drug_check=include_drug_check,
            include_guidelines=include_guidelines,
        )

        elapsed_ms = int((time.monotonic() - case_start) * 1000)

        # Per-step status map (step_id -> status string) when a state exists,
        # even for partially failed runs.
        step_results = {}
        if state:
            step_results = {s.step_id: s.status.value for s in state.steps}

        # Score
        scores = {}
        details = {}

        # Parse success: the parser produced a patient profile, regardless of
        # whether downstream steps (and therefore `report`) succeeded.
        scores["parse_success"] = 1.0 if (state and state.patient_profile) else 0.0

        # Field completeness: fraction of structured fields extracted.
        scores["field_completeness"] = score_field_completeness(state)

        if report:
            # Has differential
            scores["has_differential"] = 1.0 if len(report.differential_diagnosis) > 0 else 0.0

            # Has recommendations
            scores["has_recommendations"] = 1.0 if len(report.suggested_next_steps) > 0 else 0.0

            # Has guideline recommendations
            scores["has_guidelines"] = 1.0 if len(report.guideline_recommendations) > 0 else 0.0

            # Specialty alignment: keyword check over all free text the
            # report produced (summary, diagnoses, guidelines, next steps).
            full_report_text = " ".join([
                report.patient_summary or "",
                " ".join(d.diagnosis for d in report.differential_diagnosis),
                " ".join(report.guideline_recommendations),
                " ".join(a.action for a in report.suggested_next_steps),
            ])
            scores["specialty_alignment"] = 1.0 if check_specialty_alignment(
                full_report_text, specialty
            ) else 0.0

            # Conflict detection worked (if applicable)
            if state and state.conflict_detection:
                scores["conflict_detection_ran"] = 1.0
            else:
                scores["conflict_detection_ran"] = 0.0

            details = {
                "specialty": specialty,
                "num_diagnoses": len(report.differential_diagnosis),
                "num_recommendations": len(report.suggested_next_steps),
                "field_completeness": scores["field_completeness"],
                "num_conflicts": len(report.conflicts) if report.conflicts else 0,
            }

            print(f"✓ fields={scores['field_completeness']:.0%} dx={len(report.differential_diagnosis)} ({elapsed_ms}ms)")
        else:
            # Pipeline produced no report: zero every report-derived metric
            # (parse_success / field_completeness above may still be nonzero).
            scores.update({
                "has_differential": 0.0,
                "has_recommendations": 0.0,
                "has_guidelines": 0.0,
                "specialty_alignment": 0.0,
                "conflict_detection_ran": 0.0,
            })
            details = {"specialty": specialty, "error": error}
            print(f"✗ FAILED: {error[:80] if error else 'unknown'}")

        results.append(ValidationResult(
            case_id=case.case_id,
            source_dataset="mtsamples",
            success=report is not None,
            scores=scores,
            pipeline_time_ms=elapsed_ms,
            step_results=step_results,
            report_summary=report.patient_summary[:200] if report else None,
            error=error,
            details=details,
        ))

        # Throttle between cases (skipped after the last one).
        if i < len(cases) - 1:
            await asyncio.sleep(delay_between_cases)

    # Aggregate: mean of each metric over ALL cases (failures count as 0.0).
    total = len(results)
    successful = sum(1 for r in results if r.success)

    metric_names = [
        "parse_success", "field_completeness", "has_differential",
        "has_recommendations", "has_guidelines", "specialty_alignment",
        "conflict_detection_ran",
    ]
    metrics = {}
    for m in metric_names:
        values = [r.scores.get(m, 0.0) for r in results]
        metrics[m] = sum(values) / len(values) if values else 0.0

    # Average runtime over successful cases only.
    times = [r.pipeline_time_ms for r in results if r.success]
    metrics["avg_pipeline_time_ms"] = sum(times) / len(times) if times else 0

    summary = ValidationSummary(
        dataset="mtsamples",
        total_cases=total,
        successful_cases=successful,
        failed_cases=total - successful,
        metrics=metrics,
        per_case=results,
        run_duration_sec=time.time() - start_time,
    )

    return summary
367
+
368
+
369
+ # ──────────────────────────────────────────────
370
+ # Standalone runner
371
+ # ──────────────────────────────────────────────
372
+
373
async def main():
    """Command-line entry point: fetch MTSamples cases, validate, save results."""
    import argparse

    cli = argparse.ArgumentParser(description="MTSamples Validation")
    cli.add_argument("--max-cases", type=int, default=10, help="Number of cases to evaluate")
    cli.add_argument("--seed", type=int, default=42, help="Random seed")
    cli.add_argument("--no-drugs", action="store_true", help="Skip drug interaction check")
    cli.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
    cli.add_argument("--delay", type=float, default=2.0, help="Delay between cases (seconds)")
    opts = cli.parse_args()

    print("MTSamples Validation Harness")
    print("=" * 40)

    # Fetch (cache-aware), run the harness, then persist and display results.
    selected = await fetch_mtsamples(max_cases=opts.max_cases, seed=opts.seed)
    outcome = await validate_mtsamples(
        selected,
        include_drug_check=not opts.no_drugs,
        include_guidelines=not opts.no_guidelines,
        delay_between_cases=opts.delay,
    )

    print_summary(outcome)
    saved_to = save_results(outcome)
    print(f"Results saved to: {saved_to}")


if __name__ == "__main__":
    asyncio.run(main())
src/backend/validation/harness_pmc.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PMC Case Reports fetcher and validation harness.
3
+
4
+ Fetches published clinical case reports from PubMed Central and evaluates
5
+ the CDS pipeline's diagnostic accuracy against gold-standard diagnoses.
6
+
7
+ Source: NCBI PubMed / PubMed Central (E-utilities API)
8
+ Format: XML abstracts with case presentations and final diagnoses
9
+
10
+ Metrics:
11
+ - diagnostic_accuracy: Correct diagnosis appears in differential
12
+ - top3_accuracy: Correct diagnosis in top 3
13
+ - parse_success_rate: Pipeline completed without crashing
14
+ - has_recommendations: Report includes actionable next steps
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import asyncio
19
+ import json
20
+ import random
21
+ import re
22
+ import time
23
+ import xml.etree.ElementTree as ET
24
+ from pathlib import Path
25
+ from typing import List, Optional, Tuple
26
+
27
+ import httpx
28
+
29
+ from validation.base import (
30
+ DATA_DIR,
31
+ ValidationCase,
32
+ ValidationResult,
33
+ ValidationSummary,
34
+ diagnosis_in_differential,
35
+ ensure_data_dir,
36
+ fuzzy_match,
37
+ normalize_text,
38
+ print_summary,
39
+ run_cds_pipeline,
40
+ save_results,
41
+ )
42
+
43
+
44
+ # ──────────────────────────────────────────────
45
+ # NCBI E-utilities configuration
46
+ # ──────────────────────────────────────────────
47
+
48
# Base URL for NCBI Entrez E-utilities; esearch finds PMIDs, efetch
# retrieves the full records.
EUTILS_BASE = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
ESEARCH_URL = f"{EUTILS_BASE}/esearch.fcgi"
EFETCH_URL = f"{EUTILS_BASE}/efetch.fcgi"

# Curated search queries for case reports with clear diagnoses.
# Each tuple: (search_query, expected_specialty).
# The [Title] filters keep results to articles whose title names both the
# "case report" framing and the condition — the quoted condition term also
# doubles as the fallback gold diagnosis in _extract_case_and_diagnosis().
CASE_REPORT_QUERIES = [
    ('"case report"[Title] AND "myocardial infarction"[Title] AND diagnosis', "Cardiology"),
    ('"case report"[Title] AND "pneumonia"[Title] AND diagnosis', "Pulmonology"),
    ('"case report"[Title] AND "diabetic ketoacidosis"[Title]', "Endocrinology"),
    ('"case report"[Title] AND "stroke"[Title] AND diagnosis', "Neurology"),
    ('"case report"[Title] AND "appendicitis"[Title] AND diagnosis', "Surgery"),
    ('"case report"[Title] AND "pulmonary embolism"[Title]', "Pulmonology"),
    ('"case report"[Title] AND "sepsis"[Title] AND management', "Critical Care"),
    ('"case report"[Title] AND "heart failure"[Title] AND management', "Cardiology"),
    ('"case report"[Title] AND "pancreatitis"[Title] AND diagnosis', "Gastroenterology"),
    ('"case report"[Title] AND "meningitis"[Title] AND diagnosis', "Neurology/ID"),
    ('"case report"[Title] AND "urinary tract infection"[Title]', "Urology/ID"),
    ('"case report"[Title] AND "thyroid"[Title] AND "nodule"', "Endocrinology"),
    ('"case report"[Title] AND "deep vein thrombosis"[Title]', "Hematology"),
    ('"case report"[Title] AND "anaphylaxis"[Title]', "Allergy/EM"),
    ('"case report"[Title] AND "renal failure"[Title] AND acute', "Nephrology"),
    ('"case report"[Title] AND "liver cirrhosis"[Title]', "Hepatology"),
    ('"case report"[Title] AND "asthma"[Title] AND exacerbation', "Pulmonology"),
    ('"case report"[Title] AND "seizure"[Title] AND diagnosis', "Neurology"),
    ('"case report"[Title] AND "hypoglycemia"[Title]', "Endocrinology"),
    ('"case report"[Title] AND "gastrointestinal bleeding"[Title]', "Gastroenterology"),
]
76
+
77
+
78
async def fetch_pmc_cases(
    max_cases: int = 20,
    seed: int = 42,
) -> List[ValidationCase]:
    """
    Fetch case reports from PubMed and convert to ValidationCase objects.

    Uses PubMed E-utilities to search for case reports with clear diagnoses,
    then extracts the clinical presentation and diagnosis from abstracts.
    Results are cached under DATA_DIR/pmc_cases.json; a sufficiently large
    cache is sampled instead of re-fetching.

    Args:
        max_cases: Maximum number of cases to fetch
        seed: Random seed for reproducible selection

    Returns:
        Up to ``max_cases`` ValidationCase objects (fewer if queries fail
        or abstracts are unusable).
    """
    ensure_data_dir()
    cache_path = DATA_DIR / "pmc_cases.json"

    if cache_path.exists():
        print(f" Loading PMC cases from cache: {cache_path}")
        cached = json.loads(cache_path.read_text(encoding="utf-8"))
        cases = [ValidationCase(**c) for c in cached]
        if len(cases) >= max_cases:
            # Seeded sample so repeat runs pick the same subset.
            random.seed(seed)
            return random.sample(cases, min(max_cases, len(cases)))
        # Fall through to fetch more if cache is insufficient
        # NOTE(review): the undersized cache is later *replaced* by the
        # newly fetched set rather than merged with it — confirm intended.

    print(f" Fetching case reports from PubMed...")
    cases = await _fetch_from_pubmed(max_cases, seed)

    if cases:
        # Cache: serialize the dataclass fields explicitly so the JSON
        # round-trips back through ValidationCase(**c) above.
        cached_data = [
            {
                "case_id": c.case_id,
                "source_dataset": c.source_dataset,
                "input_text": c.input_text,
                "ground_truth": c.ground_truth,
                "metadata": c.metadata,
            }
            for c in cases
        ]
        cache_path.write_text(json.dumps(cached_data, indent=2), encoding="utf-8")
        print(f" Cached {len(cases)} PMC cases to {cache_path}")

    print(f" Loaded {len(cases)} PMC case reports")
    return cases
124
+
125
+
126
async def _fetch_from_pubmed(max_cases: int, seed: int) -> List[ValidationCase]:
    """Fetch case reports via PubMed E-utilities.

    Samples up to ``max_cases`` queries from CASE_REPORT_QUERIES (seeded for
    reproducibility), takes the first usable article per query, and converts
    each into a ValidationCase. Network or parse failures for one query are
    logged and skipped, never fatal.
    """
    cases = []
    random.seed(seed)
    # One query ≈ one case, since only the first PMID per query is used.
    queries = random.sample(CASE_REPORT_QUERIES, min(max_cases, len(CASE_REPORT_QUERIES)))

    async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
        for query_text, specialty in queries:
            if len(cases) >= max_cases:
                break

            try:
                # Step 1: Search for PMIDs
                pmids = await _esearch(client, query_text, retmax=3)
                if not pmids:
                    continue

                # Step 2: Fetch abstracts
                for pmid in pmids[:1]:  # Take first result per query
                    abstract_data = await _efetch_abstract(client, pmid)
                    if not abstract_data:
                        continue

                    title, abstract = abstract_data

                    # Step 3: Extract case presentation and diagnosis
                    presentation, diagnosis = _extract_case_and_diagnosis(title, abstract, query_text)
                    if not presentation or not diagnosis:
                        continue

                    cases.append(ValidationCase(
                        case_id=f"pmc_{pmid}",
                        source_dataset="pmc",
                        input_text=presentation,
                        ground_truth={
                            "diagnosis": diagnosis,
                            "specialty": specialty,
                            "title": title,
                        },
                        metadata={
                            "pmid": pmid,
                            # Kept for debugging/audit; not fed to the pipeline.
                            "full_abstract": abstract,
                        },
                    ))

                    if len(cases) >= max_cases:
                        break

                # NCBI rate limit: max 3 requests/second without API key
                await asyncio.sleep(0.5)

            except Exception as e:
                # Best-effort fetch: report and move on to the next query.
                print(f" Warning: Query failed '{query_text[:40]}...': {e}")
                continue

    return cases
182
+
183
+
184
async def _esearch(client: httpx.AsyncClient, query: str, retmax: int = 3) -> List[str]:
    """Run a PubMed esearch query and return the matching PMIDs (possibly empty)."""
    response = await client.get(
        ESEARCH_URL,
        params={
            "db": "pubmed",
            "term": query,
            "retmax": retmax,
            "retmode": "json",
            "sort": "relevance",
        },
    )
    response.raise_for_status()
    payload = response.json()
    # esearch JSON shape: {"esearchresult": {"idlist": ["12345", ...]}}
    return payload.get("esearchresult", {}).get("idlist", [])
197
+
198
+
199
async def _efetch_abstract(client: httpx.AsyncClient, pmid: str) -> Optional[Tuple[str, str]]:
    """Fetch the title and abstract for a PMID.

    Args:
        client: Shared async HTTP client.
        pmid: PubMed identifier to fetch.

    Returns:
        ``(title, abstract)`` on success, or ``None`` when the record has no
        usable abstract (under 100 characters) or the XML fails to parse.
    """
    params = {
        "db": "pubmed",
        "id": pmid,
        "retmode": "xml",
    }
    r = await client.get(EFETCH_URL, params=params)
    r.raise_for_status()

    try:
        root = ET.fromstring(r.text)

        # Title: itertext() flattens inline markup (<i>, <sup>, ...) that a
        # plain .text read would silently drop.
        title_el = root.find(".//ArticleTitle")
        title = "".join(title_el.itertext()).strip() if title_el is not None else ""

        # Abstract: structured abstracts carry one AbstractText per section,
        # each optionally tagged with a Label (e.g. "CASE PRESENTATION").
        abstract_parts = []
        for abs_text in root.findall(".//AbstractText"):
            label = abs_text.get("Label", "")
            # itertext() collects text at any nesting depth, unlike a
            # one-level text/tail walk.
            full_text = "".join(abs_text.itertext()).strip()
            if label:
                abstract_parts.append(f"{label}: {full_text}")
            else:
                abstract_parts.append(full_text)

        abstract = " ".join(abstract_parts)

        # Too-short abstracts cannot contain a usable case presentation.
        if len(abstract) < 100:
            return None

        return title, abstract

    except ET.ParseError:
        return None
239
+
240
+
241
+ def _extract_case_and_diagnosis(
242
+ title: str, abstract: str, search_query: str
243
+ ) -> Tuple[Optional[str], Optional[str]]:
244
+ """
245
+ Extract the clinical presentation and final diagnosis from a case report abstract.
246
+
247
+ Strategy:
248
+ 1. Try structured abstract sections (CASE PRESENTATION, DIAGNOSIS, etc.)
249
+ 2. Extract diagnosis from the title (common pattern: "A case of [diagnosis]")
250
+ 3. Fall back to using the search condition as the expected diagnosis
251
+ """
252
+ # Try to extract diagnosis from title
253
+ diagnosis = None
254
+ title_patterns = [
255
+ r"case (?:report )?of (.+?)(?:\.|:|$)",
256
+ r"presenting (?:as|with) (.+?)(?:\.|:|$)",
257
+ r"diagnosed (?:as|with) (.+?)(?:\.|:|$)",
258
+ r"rare case of (.+?)(?:\.|:|$)",
259
+ r"unusual (?:case|presentation) of (.+?)(?:\.|:|$)",
260
+ # Pattern: "Diagnosis Name: A Case Report"
261
+ r"^(.+?):\s*[Aa]\s*[Cc]ase\s*[Rr]eport",
262
+ # Pattern: "Diagnosis Name - Case Report"
263
+ r"^(.+?)\s*[-–—]\s*[Cc]ase\s*[Rr]eport",
264
+ # Pattern: "Case of Diagnosis Name"
265
+ r"[Cc]ase\s+of\s+(.+?)(?:\.|:|,|$)",
266
+ ]
267
+ for pattern in title_patterns:
268
+ match = re.search(pattern, title, re.IGNORECASE)
269
+ if match:
270
+ diagnosis = match.group(1).strip()
271
+ break
272
+
273
+ if not diagnosis:
274
+ # Extract from search query
275
+ # queries look like: '"case report"[Title] AND "myocardial infarction"[Title]'
276
+ # Find all quoted terms and pick the one that isn't "case report"
277
+ matches = re.findall(r'"([^"]+)"', search_query)
278
+ for m in matches:
279
+ if m.lower() != "case report":
280
+ diagnosis = m
281
+ break
282
+
283
+ if not diagnosis:
284
+ return None, None
285
+
286
+ # Clean diagnosis text
287
+ diagnosis = diagnosis.strip().rstrip('.')
288
+
289
+ # Extract clinical presentation
290
+ # For structured abstracts, look for specific sections
291
+ presentation_sections = ["CASE PRESENTATION", "CASE REPORT", "CASE", "CLINICAL PRESENTATION", "HISTORY"]
292
+ conclusion_sections = ["CONCLUSION", "DISCUSSION", "OUTCOME", "DIAGNOSIS", "RESULTS"]
293
+
294
+ # Try to split abstract into presentation vs conclusion
295
+ presentation = abstract
296
+
297
+ # Look for section boundaries in structured abstracts
298
+ for cs in conclusion_sections:
299
+ pattern = re.compile(rf'\b{cs}\b[:\s]', re.IGNORECASE)
300
+ match = pattern.search(abstract)
301
+ if match:
302
+ # Everything before the conclusion is the presentation
303
+ candidate = abstract[:match.start()].strip()
304
+ if len(candidate) > 100:
305
+ presentation = candidate
306
+ break
307
+
308
+ # Clean up
309
+ presentation = presentation.strip()
310
+ if len(presentation) < 50:
311
+ presentation = abstract # Use full abstract if extraction is too short
312
+
313
+ return presentation, diagnosis
314
+
315
+
316
+ # ──────────────────────────────────────────────
317
+ # Validation harness
318
+ # ──────────────────────────────────────────────
319
+
320
async def validate_pmc(
    cases: List[ValidationCase],
    include_drug_check: bool = True,
    include_guidelines: bool = True,
    delay_between_cases: float = 2.0,
) -> ValidationSummary:
    """
    Run PMC case reports through the CDS pipeline and score results.

    Args:
        cases: PMC case reports; ground_truth carries "diagnosis" (gold
            label), "specialty", and "title".
        include_drug_check: Forwarded to the pipeline.
        include_guidelines: Forwarded to the pipeline.
        delay_between_cases: Sleep (seconds) between cases to avoid
            hammering rate-limited upstream services.

    Returns:
        ValidationSummary with per-case results and dataset-level averages
        (diagnostic_accuracy, top3/top1 accuracy, parse_success,
        has_recommendations, avg_pipeline_time_ms).
    """
    results: List[ValidationResult] = []
    start_time = time.time()

    for i, case in enumerate(cases):
        dx = case.ground_truth.get("diagnosis", "?")
        specialty = case.ground_truth.get("specialty", "?")
        # Progress line; end="" so the per-case verdict is appended later.
        print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty} — {dx[:40]}): ", end="", flush=True)

        # monotonic() for per-case wall time (immune to clock adjustments).
        case_start = time.monotonic()

        state, report, error = await run_cds_pipeline(
            patient_text=case.input_text,
            include_drug_check=include_drug_check,
            include_guidelines=include_guidelines,
        )

        elapsed_ms = int((time.monotonic() - case_start) * 1000)

        # Per-step status map (step_id -> status string), even on failure.
        step_results = {}
        if state:
            step_results = {s.step_id: s.status.value for s in state.steps}

        scores = {}
        details = {}
        target_diagnosis = case.ground_truth["diagnosis"]

        if report:
            # Diagnostic accuracy (anywhere in differential)
            found_any, rank_any = diagnosis_in_differential(target_diagnosis, report)
            scores["diagnostic_accuracy"] = 1.0 if found_any else 0.0

            # Top-3 accuracy
            found_top3, rank3 = diagnosis_in_differential(target_diagnosis, report, top_n=3)
            scores["top3_accuracy"] = 1.0 if found_top3 else 0.0

            # Top-1 accuracy
            found_top1, rank1 = diagnosis_in_differential(target_diagnosis, report, top_n=1)
            scores["top1_accuracy"] = 1.0 if found_top1 else 0.0

            # Parse success: a report implies the pipeline completed.
            scores["parse_success"] = 1.0

            # Has recommendations
            scores["has_recommendations"] = 1.0 if len(report.suggested_next_steps) > 0 else 0.0

            details = {
                "target_diagnosis": target_diagnosis,
                "top_diagnosis": report.differential_diagnosis[0].diagnosis if report.differential_diagnosis else "NONE",
                "num_diagnoses": len(report.differential_diagnosis),
                # -1 sentinel when the gold diagnosis never appeared.
                "found_at_rank": rank_any if found_any else -1,
                "all_diagnoses": [d.diagnosis for d in report.differential_diagnosis[:5]],
            }

            icon = "✓" if found_any else "✗"
            top_dx = report.differential_diagnosis[0].diagnosis if report.differential_diagnosis else "NONE"
            print(f"{icon} top1={'Y' if found_top1 else 'N'} diag={'Y' if found_any else 'N'} | top: {top_dx[:30]} ({elapsed_ms}ms)")
        else:
            # No report produced: every metric zeroes out for this case.
            scores = {
                "diagnostic_accuracy": 0.0,
                "top3_accuracy": 0.0,
                "top1_accuracy": 0.0,
                "parse_success": 0.0,
                "has_recommendations": 0.0,
            }
            details = {"target_diagnosis": target_diagnosis, "error": error}
            print(f"✗ FAILED: {error[:80] if error else 'unknown'}")

        results.append(ValidationResult(
            case_id=case.case_id,
            source_dataset="pmc",
            success=report is not None,
            scores=scores,
            pipeline_time_ms=elapsed_ms,
            step_results=step_results,
            report_summary=report.patient_summary[:200] if report else None,
            error=error,
            details=details,
        ))

        # Throttle between cases (skipped after the last one).
        if i < len(cases) - 1:
            await asyncio.sleep(delay_between_cases)

    # Aggregate: mean of each metric over ALL cases (failures count as 0.0).
    total = len(results)
    successful = sum(1 for r in results if r.success)

    metric_names = ["diagnostic_accuracy", "top3_accuracy", "top1_accuracy", "parse_success", "has_recommendations"]
    metrics = {}
    for m in metric_names:
        values = [r.scores.get(m, 0.0) for r in results]
        metrics[m] = sum(values) / len(values) if values else 0.0

    # Average runtime over successful cases only.
    times = [r.pipeline_time_ms for r in results if r.success]
    metrics["avg_pipeline_time_ms"] = sum(times) / len(times) if times else 0

    summary = ValidationSummary(
        dataset="pmc",
        total_cases=total,
        successful_cases=successful,
        failed_cases=total - successful,
        metrics=metrics,
        per_case=results,
        run_duration_sec=time.time() - start_time,
    )

    return summary
435
+
436
+
437
+ # ──────────────────────────────────────────────
438
+ # Standalone runner
439
+ # ──────────────────────────────────────────────
440
+
441
async def main():
    """Command-line entry point: fetch PMC case reports, validate, save results."""
    import argparse

    cli = argparse.ArgumentParser(description="PMC Case Reports Validation")
    cli.add_argument("--max-cases", type=int, default=10, help="Number of cases to evaluate")
    cli.add_argument("--seed", type=int, default=42, help="Random seed")
    cli.add_argument("--no-drugs", action="store_true", help="Skip drug interaction check")
    cli.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
    cli.add_argument("--delay", type=float, default=2.0, help="Delay between cases (seconds)")
    opts = cli.parse_args()

    print("PMC Case Reports Validation Harness")
    print("=" * 40)

    # Fetch (cache-aware), run the harness, then persist and display results.
    selected = await fetch_pmc_cases(max_cases=opts.max_cases, seed=opts.seed)
    outcome = await validate_pmc(
        selected,
        include_drug_check=not opts.no_drugs,
        include_guidelines=not opts.no_guidelines,
        delay_between_cases=opts.delay,
    )

    print_summary(outcome)
    saved_to = save_results(outcome)
    print(f"Results saved to: {saved_to}")


if __name__ == "__main__":
    asyncio.run(main())
src/backend/validation/run_validation.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified validation runner for the Clinical Decision Support Agent.
3
+
4
+ Runs all three dataset validations (MedQA, MTSamples, PMC Case Reports)
5
+ and produces a combined summary report.
6
+
7
+ Usage:
8
+ # From src/backend directory:
9
+ python -m validation.run_validation --all --max-cases 10
10
+ python -m validation.run_validation --medqa --max-cases 20
11
+ python -m validation.run_validation --mtsamples --max-cases 15
12
+ python -m validation.run_validation --pmc --max-cases 10
13
+
14
+ # Fetch data only (no pipeline execution):
15
+ python -m validation.run_validation --fetch-only
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import asyncio
20
+ import json
21
+ import sys
22
+ import time
23
+ from datetime import datetime, timezone
24
+ from pathlib import Path
25
+
26
+ # Ensure backend is importable
27
+ BACKEND_DIR = Path(__file__).resolve().parent.parent
28
+ if str(BACKEND_DIR) not in sys.path:
29
+ sys.path.insert(0, str(BACKEND_DIR))
30
+
31
+ from validation.base import (
32
+ ValidationSummary,
33
+ print_summary,
34
+ save_results,
35
+ )
36
+ from validation.harness_medqa import fetch_medqa, validate_medqa
37
+ from validation.harness_mtsamples import fetch_mtsamples, validate_mtsamples
38
+ from validation.harness_pmc import fetch_pmc_cases, validate_pmc
39
+
40
+
41
async def run_all_validations(
    run_medqa: bool = True,
    run_mtsamples: bool = True,
    run_pmc: bool = True,
    max_cases: int = 10,
    seed: int = 42,
    include_drug_check: bool = True,
    include_guidelines: bool = True,
    delay: float = 2.0,
    fetch_only: bool = False,
) -> dict:
    """
    Run validation against selected datasets.

    Args:
        run_medqa / run_mtsamples / run_pmc: Which harnesses to execute.
        max_cases: Case budget applied to each selected dataset.
        seed: Random seed forwarded to each fetcher for reproducibility.
        include_drug_check: Forwarded to the pipeline.
        include_guidelines: Forwarded to the pipeline.
        delay: Per-case throttle (seconds) forwarded to each harness.
        fetch_only: Download/caches data only; the pipeline never runs and
            no summaries are produced.

    Returns:
        dict of {dataset_name: ValidationSummary} — empty in fetch-only mode.
    """
    results = {}
    start = time.time()

    # ── MedQA ──
    if run_medqa:
        print("\n" + "=" * 60)
        print(" DATASET 1: MedQA (USMLE-style diagnostic accuracy)")
        print("=" * 60)

        cases = await fetch_medqa(max_cases=max_cases, seed=seed)

        if fetch_only:
            print(f" Fetched {len(cases)} MedQA cases (fetch-only mode)")
        else:
            summary = await validate_medqa(
                cases,
                include_drug_check=include_drug_check,
                include_guidelines=include_guidelines,
                delay_between_cases=delay,
            )
            # Per-dataset results are printed and saved immediately so a
            # later dataset failing doesn't lose earlier output.
            print_summary(summary)
            save_results(summary)
            results["medqa"] = summary

    # ── MTSamples ──
    if run_mtsamples:
        print("\n" + "=" * 60)
        print(" DATASET 2: MTSamples (clinical note parsing robustness)")
        print("=" * 60)

        cases = await fetch_mtsamples(max_cases=max_cases, seed=seed)

        if fetch_only:
            print(f" Fetched {len(cases)} MTSamples cases (fetch-only mode)")
        else:
            summary = await validate_mtsamples(
                cases,
                include_drug_check=include_drug_check,
                include_guidelines=include_guidelines,
                delay_between_cases=delay,
            )
            print_summary(summary)
            save_results(summary)
            results["mtsamples"] = summary

    # ── PMC Case Reports ──
    if run_pmc:
        print("\n" + "=" * 60)
        print(" DATASET 3: PMC Case Reports (real-world diagnostic accuracy)")
        print("=" * 60)

        cases = await fetch_pmc_cases(max_cases=max_cases, seed=seed)

        if fetch_only:
            print(f" Fetched {len(cases)} PMC cases (fetch-only mode)")
        else:
            summary = await validate_pmc(
                cases,
                include_drug_check=include_drug_check,
                include_guidelines=include_guidelines,
                delay_between_cases=delay,
            )
            print_summary(summary)
            save_results(summary)
            results["pmc"] = summary

    # ── Combined Summary ──
    total_duration = time.time() - start

    if results and not fetch_only:
        _print_combined_summary(results, total_duration)
        _save_combined_report(results, total_duration)

    return results
131
+
132
+
133
+ def _print_combined_summary(results: dict, total_duration: float):
134
+ """Print a combined summary across all datasets."""
135
+ print("\n" + "=" * 70)
136
+ print(" COMBINED VALIDATION REPORT")
137
+ print("=" * 70)
138
+
139
+ # Header
140
+ print(f"\n {'Dataset':<15} {'Cases':>6} {'Success':>8} {'Key Metric':>25} {'Value':>8}")
141
+ print(f" {'-'*15} {'-'*6} {'-'*8} {'-'*25} {'-'*8}")
142
+
143
+ for name, summary in results.items():
144
+ # Pick the most important metric for each dataset
145
+ if name == "medqa":
146
+ key_metric = "top3_accuracy"
147
+ elif name == "mtsamples":
148
+ key_metric = "parse_success"
149
+ elif name == "pmc":
150
+ key_metric = "diagnostic_accuracy"
151
+ else:
152
+ key_metric = list(summary.metrics.keys())[0] if summary.metrics else "N/A"
153
+
154
+ value = summary.metrics.get(key_metric, 0.0)
155
+ print(
156
+ f" {name:<15} {summary.total_cases:>6} "
157
+ f"{summary.successful_cases:>8} "
158
+ f"{key_metric:>25} {value:>7.1%}"
159
+ )
160
+
161
+ # All metrics
162
+ print(f"\n {'─' * 66}")
163
+ for name, summary in results.items():
164
+ print(f"\n {name.upper()} metrics:")
165
+ for metric, value in sorted(summary.metrics.items()):
166
+ if "time" in metric and isinstance(value, (int, float)):
167
+ print(f" {metric:<35} {value:.0f}ms")
168
+ elif isinstance(value, float):
169
+ print(f" {metric:<35} {value:.1%}")
170
+
171
+ # Totals
172
+ total_cases = sum(s.total_cases for s in results.values())
173
+ total_success = sum(s.successful_cases for s in results.values())
174
+ print(f"\n Total cases: {total_cases}")
175
+ print(f" Total success: {total_success}")
176
+ print(f" Total duration: {total_duration:.1f}s ({total_duration/60:.1f}min)")
177
+ print(f" Timestamp: {datetime.now(timezone.utc).isoformat()}")
178
+ print("=" * 70)
179
+
180
+
181
def _save_combined_report(results: dict, total_duration: float):
    """Write the combined multi-dataset report to results/combined_<ts>.json
    next to this module, creating the directory if needed."""
    out_dir = Path(__file__).resolve().parent / "results"
    out_dir.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    out_path = out_dir / f"combined_{stamp}.json"

    payload = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "total_duration_sec": total_duration,
        # Only aggregate numbers are persisted here; per-case detail lives in
        # each dataset's own results file.
        "datasets": {
            name: {
                "total_cases": summary.total_cases,
                "successful_cases": summary.successful_cases,
                "failed_cases": summary.failed_cases,
                "metrics": summary.metrics,
                "run_duration_sec": summary.run_duration_sec,
            }
            for name, summary in results.items()
        },
    }

    # default=str keeps any non-JSON-native metric values serializable.
    out_path.write_text(json.dumps(payload, indent=2, default=str))
    print(f"\n Combined report saved to: {out_path}")
206
+
207
+
208
def main():
    """Parse command-line arguments and launch the requested validations."""
    import argparse

    ap = argparse.ArgumentParser(
        description="CDS Agent Validation Suite",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m validation.run_validation --all --max-cases 10
  python -m validation.run_validation --medqa --max-cases 50
  python -m validation.run_validation --fetch-only
  python -m validation.run_validation --medqa --pmc --max-cases 20 --no-drugs
""",
    )

    # Dataset selection flags are all identical booleans, so register
    # them table-driven instead of one call per flag.
    dataset_group = ap.add_argument_group("Datasets")
    for flag, help_text in (
        ("--all", "Run all three datasets"),
        ("--medqa", "Run MedQA validation"),
        ("--mtsamples", "Run MTSamples validation"),
        ("--pmc", "Run PMC Case Reports validation"),
    ):
        dataset_group.add_argument(flag, action="store_true", help=help_text)

    cfg_group = ap.add_argument_group("Configuration")
    cfg_group.add_argument("--max-cases", type=int, default=10, help="Cases per dataset (default: 10)")
    cfg_group.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    cfg_group.add_argument("--delay", type=float, default=2.0, help="Delay between cases in seconds (default: 2.0)")
    cfg_group.add_argument("--no-drugs", action="store_true", help="Skip drug interaction checks")
    cfg_group.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
    cfg_group.add_argument("--fetch-only", action="store_true", help="Only download data, don't run pipeline")

    args = ap.parse_args()

    # With no dataset flag at all, behave as if --all had been passed.
    if not (args.all or args.medqa or args.mtsamples or args.pmc):
        args.all = True

    run_medqa = args.all or args.medqa
    run_mtsamples = args.all or args.mtsamples
    run_pmc = args.all or args.pmc

    print("╔════════════════════════════════════════════════════════╗")
    print("║   Clinical Decision Support Agent — Validation Suite   ║")
    print("╚════════════════════════════════════════════════════════╝")

    # Assemble the selected-dataset label the same way the banner prints it:
    # each enabled name followed by a trailing space.
    selected = "".join(
        label
        for enabled, label in (
            (run_medqa, "MedQA "),
            (run_mtsamples, "MTSamples "),
            (run_pmc, "PMC "),
        )
        if enabled
    )
    print(f"\n Datasets: {selected}")
    print(f" Cases/dataset: {args.max_cases}")
    print(f" Drug check: {'Yes' if not args.no_drugs else 'No'}")
    print(f" Guidelines: {'Yes' if not args.no_guidelines else 'No'}")
    print(f" Fetch only: {'Yes' if args.fetch_only else 'No'}")

    asyncio.run(run_all_validations(
        run_medqa=run_medqa,
        run_mtsamples=run_mtsamples,
        run_pmc=run_pmc,
        max_cases=args.max_cases,
        seed=args.seed,
        include_drug_check=not args.no_drugs,
        include_guidelines=not args.no_guidelines,
        delay=args.delay,
        fetch_only=args.fetch_only,
    ))
270
+
271
+
272
# Script entry point: run the CLI only when executed directly
# (e.g. `python -m validation.run_validation`), not on import.
if __name__ == "__main__":
    main()