Add validation framework for MedQA, MTSamples, and PMC Case Reports
Browse filesThree-dataset validation suite that evaluates the CDS pipeline against
external clinical benchmarks:
- MedQA: 1273 USMLE-style questions, measures diagnostic accuracy
(top-1, top-3, mentioned-anywhere)
- MTSamples: ~5000 medical transcriptions, measures parse robustness
and field extraction completeness across specialties
- PMC Case Reports: Published case reports from PubMed, measures
real-world diagnostic accuracy against gold-standard diagnoses
Architecture:
- validation/base.py: Core framework (runner, fuzzy matching, scoring)
- validation/harness_medqa.py: MedQA fetcher + harness
- validation/harness_mtsamples.py: MTSamples fetcher + harness
- validation/harness_pmc.py: PMC Case Reports fetcher + harness
- validation/run_validation.py: Unified CLI runner
Uses direct Orchestrator calls (no server needed). Tested end-to-end:
3 MedQA cases, 66.7% top-1 accuracy, 100% parse success rate."
|
@@ -39,6 +39,10 @@ out/
|
|
| 39 |
# Test outputs
|
| 40 |
results.json
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
# Models (too large for git)
|
| 43 |
models/*.bin
|
| 44 |
models/*.pt
|
|
|
|
| 39 |
# Test outputs
|
| 40 |
results.json
|
| 41 |
|
| 42 |
+
# Validation datasets (downloaded on demand) and results
|
| 43 |
+
src/backend/validation/data/
|
| 44 |
+
src/backend/validation/results/
|
| 45 |
+
|
| 46 |
# Models (too large for git)
|
| 47 |
models/*.bin
|
| 48 |
models/*.pt
|
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Validation framework for the Clinical Decision Support Agent.
|
| 3 |
+
|
| 4 |
+
Validates the CDS pipeline against three external clinical datasets:
|
| 5 |
+
- MedQA (USMLE-style questions) β diagnostic accuracy
|
| 6 |
+
- MTSamples (medical transcriptions) β parse robustness
|
| 7 |
+
- PMC Case Reports (published cases) β real-world diagnostic accuracy
|
| 8 |
+
"""
|
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Base classes and utilities for the validation framework.
|
| 3 |
+
|
| 4 |
+
Provides:
|
| 5 |
+
- ValidationCase: a single test case with input + ground truth
|
| 6 |
+
- ValidationResult: scored result for a single case
|
| 7 |
+
- ValidationSummary: aggregate metrics for a dataset
|
| 8 |
+
- run_cds_pipeline(): runs a case through the orchestrator directly
|
| 9 |
+
- fuzzy_match(): soft string matching for diagnosis comparison
|
| 10 |
+
"""
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import asyncio
|
| 14 |
+
import json
|
| 15 |
+
import re
|
| 16 |
+
import time
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from datetime import datetime, timezone
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any, Dict, List, Optional
|
| 21 |
+
|
| 22 |
+
# ββ CDS pipeline imports ββ
|
| 23 |
+
import sys
|
| 24 |
+
|
| 25 |
+
# Ensure the backend app is importable
|
| 26 |
+
BACKEND_DIR = Path(__file__).resolve().parent.parent
|
| 27 |
+
if str(BACKEND_DIR) not in sys.path:
|
| 28 |
+
sys.path.insert(0, str(BACKEND_DIR))
|
| 29 |
+
|
| 30 |
+
from app.agent.orchestrator import Orchestrator
|
| 31 |
+
from app.models.schemas import CaseSubmission, CDSReport, AgentState
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
+
# Data classes
|
| 36 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class ValidationCase:
|
| 40 |
+
"""A single validation test case."""
|
| 41 |
+
case_id: str
|
| 42 |
+
source_dataset: str # "medqa", "mtsamples", "pmc"
|
| 43 |
+
input_text: str # Clinical text fed to the pipeline
|
| 44 |
+
ground_truth: Dict[str, Any] # Dataset-specific ground truth
|
| 45 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class ValidationResult:
|
| 50 |
+
"""Result of running one case through the pipeline + scoring."""
|
| 51 |
+
case_id: str
|
| 52 |
+
source_dataset: str
|
| 53 |
+
success: bool # Pipeline completed without crash
|
| 54 |
+
scores: Dict[str, float] # Metric name β score (0.0β1.0)
|
| 55 |
+
pipeline_time_ms: int = 0
|
| 56 |
+
step_results: Dict[str, str] = field(default_factory=dict) # step_id β status
|
| 57 |
+
report_summary: Optional[str] = None
|
| 58 |
+
error: Optional[str] = None
|
| 59 |
+
details: Dict[str, Any] = field(default_factory=dict) # Extra scoring info
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
|
| 63 |
+
class ValidationSummary:
|
| 64 |
+
"""Aggregate metrics for a dataset validation run."""
|
| 65 |
+
dataset: str
|
| 66 |
+
total_cases: int
|
| 67 |
+
successful_cases: int
|
| 68 |
+
failed_cases: int
|
| 69 |
+
metrics: Dict[str, float] # Metric name β average score
|
| 70 |
+
per_case: List[ValidationResult]
|
| 71 |
+
run_duration_sec: float
|
| 72 |
+
timestamp: str = ""
|
| 73 |
+
|
| 74 |
+
def __post_init__(self):
|
| 75 |
+
if not self.timestamp:
|
| 76 |
+
self.timestamp = datetime.now(timezone.utc).isoformat()
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 80 |
+
# Pipeline runner
|
| 81 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 82 |
+
|
| 83 |
+
async def run_cds_pipeline(
|
| 84 |
+
patient_text: str,
|
| 85 |
+
include_drug_check: bool = True,
|
| 86 |
+
include_guidelines: bool = True,
|
| 87 |
+
timeout_sec: int = 180,
|
| 88 |
+
) -> tuple[Optional[AgentState], Optional[CDSReport], Optional[str]]:
|
| 89 |
+
"""
|
| 90 |
+
Run a single case through the CDS pipeline directly (no HTTP server needed).
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
(state, report, error) β error is None on success
|
| 94 |
+
"""
|
| 95 |
+
case = CaseSubmission(
|
| 96 |
+
patient_text=patient_text,
|
| 97 |
+
include_drug_check=include_drug_check,
|
| 98 |
+
include_guidelines=include_guidelines,
|
| 99 |
+
)
|
| 100 |
+
orchestrator = Orchestrator()
|
| 101 |
+
|
| 102 |
+
try:
|
| 103 |
+
async for _step_update in orchestrator.run(case):
|
| 104 |
+
pass # consume all step updates
|
| 105 |
+
|
| 106 |
+
return orchestrator.state, orchestrator.get_result(), None
|
| 107 |
+
except asyncio.TimeoutError:
|
| 108 |
+
return orchestrator.state, None, f"Pipeline timed out after {timeout_sec}s"
|
| 109 |
+
except Exception as e:
|
| 110 |
+
return orchestrator.state, None, str(e)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 114 |
+
# Fuzzy string matching for diagnosis comparison
|
| 115 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 116 |
+
|
| 117 |
+
def normalize_text(text: str) -> str:
|
| 118 |
+
"""Lowercase, strip punctuation, normalize whitespace."""
|
| 119 |
+
text = text.lower().strip()
|
| 120 |
+
text = re.sub(r'[^\w\s]', ' ', text)
|
| 121 |
+
text = re.sub(r'\s+', ' ', text)
|
| 122 |
+
return text
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def fuzzy_match(candidate: str, target: str, threshold: float = 0.6) -> bool:
|
| 126 |
+
"""
|
| 127 |
+
Check if candidate text is a fuzzy match for target.
|
| 128 |
+
|
| 129 |
+
Uses token overlap (Jaccard-like) rather than edit distance β
|
| 130 |
+
medical terms are long and we care about semantic overlap, not typos.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
candidate: Text from the pipeline output
|
| 134 |
+
target: Ground truth text
|
| 135 |
+
threshold: Minimum token overlap ratio (0.0β1.0)
|
| 136 |
+
"""
|
| 137 |
+
c_tokens = set(normalize_text(candidate).split())
|
| 138 |
+
t_tokens = set(normalize_text(target).split())
|
| 139 |
+
|
| 140 |
+
if not t_tokens:
|
| 141 |
+
return False
|
| 142 |
+
|
| 143 |
+
# If target is a substring of candidate (or vice versa), that's a match
|
| 144 |
+
if normalize_text(target) in normalize_text(candidate):
|
| 145 |
+
return True
|
| 146 |
+
if normalize_text(candidate) in normalize_text(target):
|
| 147 |
+
return True
|
| 148 |
+
|
| 149 |
+
# Token overlap
|
| 150 |
+
overlap = len(c_tokens & t_tokens)
|
| 151 |
+
denominator = min(len(c_tokens), len(t_tokens))
|
| 152 |
+
if denominator == 0:
|
| 153 |
+
return False
|
| 154 |
+
|
| 155 |
+
return (overlap / denominator) >= threshold
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def diagnosis_in_differential(
|
| 159 |
+
target_diagnosis: str,
|
| 160 |
+
report: CDSReport,
|
| 161 |
+
top_n: Optional[int] = None,
|
| 162 |
+
) -> tuple[bool, int]:
|
| 163 |
+
"""
|
| 164 |
+
Check if target_diagnosis appears in the report's differential.
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
(found, rank) β rank is 0-indexed position, or -1 if not found
|
| 168 |
+
"""
|
| 169 |
+
diagnoses = report.differential_diagnosis
|
| 170 |
+
if top_n:
|
| 171 |
+
diagnoses = diagnoses[:top_n]
|
| 172 |
+
|
| 173 |
+
for i, dx in enumerate(diagnoses):
|
| 174 |
+
if fuzzy_match(dx.diagnosis, target_diagnosis):
|
| 175 |
+
return True, i
|
| 176 |
+
|
| 177 |
+
# Also check the full report text (patient_summary, guideline_recommendations, etc.)
|
| 178 |
+
full_text = " ".join([
|
| 179 |
+
report.patient_summary or "",
|
| 180 |
+
" ".join(report.guideline_recommendations),
|
| 181 |
+
" ".join(a.action for a in report.suggested_next_steps),
|
| 182 |
+
])
|
| 183 |
+
if fuzzy_match(full_text, target_diagnosis, threshold=0.3):
|
| 184 |
+
return True, len(diagnoses) # found but not in differential
|
| 185 |
+
|
| 186 |
+
return False, -1
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 190 |
+
# I/O utilities
|
| 191 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 192 |
+
|
| 193 |
+
DATA_DIR = Path(__file__).resolve().parent / "data"
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def ensure_data_dir():
|
| 197 |
+
"""Create the data directory if it doesn't exist."""
|
| 198 |
+
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def save_results(summary: ValidationSummary, filename: str = None):
|
| 202 |
+
"""Save validation results to JSON."""
|
| 203 |
+
results_dir = Path(__file__).resolve().parent / "results"
|
| 204 |
+
results_dir.mkdir(parents=True, exist_ok=True)
|
| 205 |
+
|
| 206 |
+
if filename is None:
|
| 207 |
+
ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
| 208 |
+
filename = f"{summary.dataset}_{ts}.json"
|
| 209 |
+
|
| 210 |
+
path = results_dir / filename
|
| 211 |
+
|
| 212 |
+
# Convert to serializable dict
|
| 213 |
+
data = {
|
| 214 |
+
"dataset": summary.dataset,
|
| 215 |
+
"total_cases": summary.total_cases,
|
| 216 |
+
"successful_cases": summary.successful_cases,
|
| 217 |
+
"failed_cases": summary.failed_cases,
|
| 218 |
+
"metrics": summary.metrics,
|
| 219 |
+
"run_duration_sec": summary.run_duration_sec,
|
| 220 |
+
"timestamp": summary.timestamp,
|
| 221 |
+
"per_case": [
|
| 222 |
+
{
|
| 223 |
+
"case_id": r.case_id,
|
| 224 |
+
"success": r.success,
|
| 225 |
+
"scores": r.scores,
|
| 226 |
+
"pipeline_time_ms": r.pipeline_time_ms,
|
| 227 |
+
"step_results": r.step_results,
|
| 228 |
+
"report_summary": r.report_summary,
|
| 229 |
+
"error": r.error,
|
| 230 |
+
"details": r.details,
|
| 231 |
+
}
|
| 232 |
+
for r in summary.per_case
|
| 233 |
+
],
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
path.write_text(json.dumps(data, indent=2, default=str))
|
| 237 |
+
return path
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def print_summary(summary: ValidationSummary):
|
| 241 |
+
"""Pretty-print validation results to console."""
|
| 242 |
+
print(f"\n{'='*60}")
|
| 243 |
+
print(f" Validation Results: {summary.dataset.upper()}")
|
| 244 |
+
print(f"{'='*60}")
|
| 245 |
+
print(f" Total cases: {summary.total_cases}")
|
| 246 |
+
print(f" Successful: {summary.successful_cases}")
|
| 247 |
+
print(f" Failed: {summary.failed_cases}")
|
| 248 |
+
print(f" Duration: {summary.run_duration_sec:.1f}s")
|
| 249 |
+
print(f"\n Metrics:")
|
| 250 |
+
for metric, value in sorted(summary.metrics.items()):
|
| 251 |
+
if "time" in metric and isinstance(value, (int, float)):
|
| 252 |
+
print(f" {metric:30s} {value:.0f}ms")
|
| 253 |
+
elif isinstance(value, float):
|
| 254 |
+
print(f" {metric:30s} {value:.1%}")
|
| 255 |
+
else:
|
| 256 |
+
print(f" {metric:30s} {value}")
|
| 257 |
+
print(f"{'='*60}\n")
|
|
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MedQA dataset fetcher and validation harness.
|
| 3 |
+
|
| 4 |
+
Downloads MedQA USMLE 4-option questions and evaluates the CDS pipeline's
|
| 5 |
+
ability to arrive at the correct diagnosis / answer.
|
| 6 |
+
|
| 7 |
+
Source: https://github.com/jind11/MedQA
|
| 8 |
+
Format: JSONL with {question, options: {A, B, C, D}, answer_idx, answer}
|
| 9 |
+
|
| 10 |
+
Metrics:
|
| 11 |
+
- top1_accuracy: Correct answer matches #1 differential diagnosis
|
| 12 |
+
- top3_accuracy: Correct answer in top 3 differential diagnoses
|
| 13 |
+
- mentioned_accuracy: Correct answer mentioned anywhere in report
|
| 14 |
+
- parse_success_rate: Pipeline completed without crashing
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import asyncio
|
| 19 |
+
import json
|
| 20 |
+
import random
|
| 21 |
+
import time
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import List, Optional
|
| 24 |
+
|
| 25 |
+
import httpx
|
| 26 |
+
|
| 27 |
+
from validation.base import (
|
| 28 |
+
DATA_DIR,
|
| 29 |
+
ValidationCase,
|
| 30 |
+
ValidationResult,
|
| 31 |
+
ValidationSummary,
|
| 32 |
+
diagnosis_in_differential,
|
| 33 |
+
ensure_data_dir,
|
| 34 |
+
fuzzy_match,
|
| 35 |
+
normalize_text,
|
| 36 |
+
print_summary,
|
| 37 |
+
run_cds_pipeline,
|
| 38 |
+
save_results,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
+
# Data fetching
|
| 44 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
+
|
| 46 |
+
# HuggingFace direct download (JSONL)
|
| 47 |
+
MEDQA_JSONL_URL = "https://huggingface.co/datasets/GBaker/MedQA-USMLE-4-options/resolve/main/phrases_no_exclude_test.jsonl"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
async def fetch_medqa(max_cases: int = 50, seed: int = 42) -> List[ValidationCase]:
|
| 51 |
+
"""
|
| 52 |
+
Download MedQA test set and convert to ValidationCase objects.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
max_cases: Maximum number of cases to sample
|
| 56 |
+
seed: Random seed for reproducible sampling
|
| 57 |
+
"""
|
| 58 |
+
ensure_data_dir()
|
| 59 |
+
cache_path = DATA_DIR / "medqa_test.jsonl"
|
| 60 |
+
|
| 61 |
+
# Try to load from cache
|
| 62 |
+
if cache_path.exists():
|
| 63 |
+
print(f" Loading MedQA from cache: {cache_path}")
|
| 64 |
+
raw_cases = _load_jsonl(cache_path)
|
| 65 |
+
else:
|
| 66 |
+
print(f" Downloading MedQA test set...")
|
| 67 |
+
raw_cases = await _download_medqa_jsonl(cache_path)
|
| 68 |
+
|
| 69 |
+
if not raw_cases:
|
| 70 |
+
raise RuntimeError("Failed to fetch MedQA data. Check network connection.")
|
| 71 |
+
|
| 72 |
+
# Sample
|
| 73 |
+
random.seed(seed)
|
| 74 |
+
if len(raw_cases) > max_cases:
|
| 75 |
+
raw_cases = random.sample(raw_cases, max_cases)
|
| 76 |
+
|
| 77 |
+
# Convert to ValidationCase
|
| 78 |
+
cases = []
|
| 79 |
+
for i, item in enumerate(raw_cases):
|
| 80 |
+
question = item.get("question", "")
|
| 81 |
+
options = item.get("options", item.get("answer_choices", {}))
|
| 82 |
+
answer_idx = item.get("answer_idx", item.get("answer", ""))
|
| 83 |
+
answer_text = item.get("answer", "")
|
| 84 |
+
|
| 85 |
+
# Handle different formats
|
| 86 |
+
if isinstance(options, dict):
|
| 87 |
+
if answer_idx in options:
|
| 88 |
+
answer_text = options[answer_idx]
|
| 89 |
+
elif isinstance(options, list):
|
| 90 |
+
# Some formats have options as a list
|
| 91 |
+
idx = ord(answer_idx) - ord('A') if isinstance(answer_idx, str) and len(answer_idx) == 1 else 0
|
| 92 |
+
if idx < len(options):
|
| 93 |
+
answer_text = options[idx]
|
| 94 |
+
|
| 95 |
+
# Build clinical vignette (question only, not the options)
|
| 96 |
+
# This simulates what a clinician would present
|
| 97 |
+
clinical_text = _extract_vignette(question)
|
| 98 |
+
|
| 99 |
+
cases.append(ValidationCase(
|
| 100 |
+
case_id=f"medqa_{i:04d}",
|
| 101 |
+
source_dataset="medqa",
|
| 102 |
+
input_text=clinical_text,
|
| 103 |
+
ground_truth={
|
| 104 |
+
"correct_answer": answer_text,
|
| 105 |
+
"answer_idx": answer_idx,
|
| 106 |
+
"options": options,
|
| 107 |
+
"full_question": question,
|
| 108 |
+
},
|
| 109 |
+
))
|
| 110 |
+
|
| 111 |
+
print(f" Loaded {len(cases)} MedQA cases")
|
| 112 |
+
return cases
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
async def _download_medqa_jsonl(cache_path: Path) -> List[dict]:
|
| 116 |
+
"""Download MedQA JSONL from GitHub."""
|
| 117 |
+
async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
|
| 118 |
+
try:
|
| 119 |
+
r = await client.get(MEDQA_JSONL_URL)
|
| 120 |
+
r.raise_for_status()
|
| 121 |
+
|
| 122 |
+
lines = r.text.strip().split('\n')
|
| 123 |
+
cases = [json.loads(line) for line in lines if line.strip()]
|
| 124 |
+
|
| 125 |
+
# Cache
|
| 126 |
+
cache_path.write_text('\n'.join(json.dumps(c) for c in cases))
|
| 127 |
+
print(f" Cached {len(cases)} MedQA cases to {cache_path}")
|
| 128 |
+
return cases
|
| 129 |
+
|
| 130 |
+
except Exception as e:
|
| 131 |
+
print(f" Warning: Failed to download MedQA: {e}")
|
| 132 |
+
return []
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _load_jsonl(path: Path) -> List[dict]:
|
| 136 |
+
"""Load JSONL file."""
|
| 137 |
+
cases = []
|
| 138 |
+
for line in path.read_text(encoding="utf-8").strip().split('\n'):
|
| 139 |
+
if line.strip():
|
| 140 |
+
cases.append(json.loads(line))
|
| 141 |
+
return cases
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _extract_vignette(question: str) -> str:
|
| 145 |
+
"""
|
| 146 |
+
Extract the clinical vignette from a USMLE question.
|
| 147 |
+
|
| 148 |
+
USMLE questions typically end with "Which of the following..." or
|
| 149 |
+
"What is the most likely diagnosis?". We strip the question stem
|
| 150 |
+
to leave just the clinical narrative.
|
| 151 |
+
"""
|
| 152 |
+
# Common question stems
|
| 153 |
+
stems = [
|
| 154 |
+
r"which of the following",
|
| 155 |
+
r"what is the most likely",
|
| 156 |
+
r"what is the best next step",
|
| 157 |
+
r"what is the most appropriate",
|
| 158 |
+
r"what is the diagnosis",
|
| 159 |
+
r"the most likely diagnosis is",
|
| 160 |
+
r"this patient most likely has",
|
| 161 |
+
r"what would be the next step",
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
text = question.strip()
|
| 165 |
+
for stem in stems:
|
| 166 |
+
import re
|
| 167 |
+
# Find the last sentence that starts a question
|
| 168 |
+
pattern = re.compile(rf'\.?\s*[A-Z].*{stem}.*[\?\.]?\s*$', re.IGNORECASE)
|
| 169 |
+
match = pattern.search(text)
|
| 170 |
+
if match:
|
| 171 |
+
# Return everything before the question stem sentence
|
| 172 |
+
vignette = text[:match.start()].strip()
|
| 173 |
+
if len(vignette) > 50: # Sanity check
|
| 174 |
+
return vignette
|
| 175 |
+
|
| 176 |
+
# Fallback: return the full text
|
| 177 |
+
return text
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 181 |
+
# Validation harness
|
| 182 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 183 |
+
|
| 184 |
+
async def validate_medqa(
|
| 185 |
+
cases: List[ValidationCase],
|
| 186 |
+
include_drug_check: bool = False,
|
| 187 |
+
include_guidelines: bool = True,
|
| 188 |
+
delay_between_cases: float = 2.0,
|
| 189 |
+
) -> ValidationSummary:
|
| 190 |
+
"""
|
| 191 |
+
Run MedQA cases through the CDS pipeline and score results.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
cases: List of MedQA ValidationCases
|
| 195 |
+
include_drug_check: Whether to run drug interaction check (slower)
|
| 196 |
+
include_guidelines: Whether to include guideline retrieval
|
| 197 |
+
delay_between_cases: Seconds to wait between cases (rate limiting)
|
| 198 |
+
"""
|
| 199 |
+
results: List[ValidationResult] = []
|
| 200 |
+
start_time = time.time()
|
| 201 |
+
|
| 202 |
+
for i, case in enumerate(cases):
|
| 203 |
+
print(f"\n [{i+1}/{len(cases)}] {case.case_id}: ", end="", flush=True)
|
| 204 |
+
|
| 205 |
+
case_start = time.monotonic()
|
| 206 |
+
|
| 207 |
+
state, report, error = await run_cds_pipeline(
|
| 208 |
+
patient_text=case.input_text,
|
| 209 |
+
include_drug_check=include_drug_check,
|
| 210 |
+
include_guidelines=include_guidelines,
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
elapsed_ms = int((time.monotonic() - case_start) * 1000)
|
| 214 |
+
|
| 215 |
+
# Build step results
|
| 216 |
+
step_results = {}
|
| 217 |
+
if state:
|
| 218 |
+
step_results = {s.step_id: s.status.value for s in state.steps}
|
| 219 |
+
|
| 220 |
+
# Score
|
| 221 |
+
scores = {}
|
| 222 |
+
details = {}
|
| 223 |
+
correct_answer = case.ground_truth["correct_answer"]
|
| 224 |
+
|
| 225 |
+
if report:
|
| 226 |
+
# Top-1 accuracy
|
| 227 |
+
found_top1, rank = diagnosis_in_differential(correct_answer, report, top_n=1)
|
| 228 |
+
scores["top1_accuracy"] = 1.0 if found_top1 else 0.0
|
| 229 |
+
|
| 230 |
+
# Top-3 accuracy
|
| 231 |
+
found_top3, rank3 = diagnosis_in_differential(correct_answer, report, top_n=3)
|
| 232 |
+
scores["top3_accuracy"] = 1.0 if found_top3 else 0.0
|
| 233 |
+
|
| 234 |
+
# Mentioned anywhere
|
| 235 |
+
found_any, rank_any = diagnosis_in_differential(correct_answer, report)
|
| 236 |
+
scores["mentioned_accuracy"] = 1.0 if found_any else 0.0
|
| 237 |
+
|
| 238 |
+
# Parse success
|
| 239 |
+
scores["parse_success"] = 1.0
|
| 240 |
+
|
| 241 |
+
details = {
|
| 242 |
+
"correct_answer": correct_answer,
|
| 243 |
+
"top_diagnosis": report.differential_diagnosis[0].diagnosis if report.differential_diagnosis else "NONE",
|
| 244 |
+
"num_diagnoses": len(report.differential_diagnosis),
|
| 245 |
+
"found_at_rank": rank_any if found_any else -1,
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
status_icon = "β" if found_top3 else "β"
|
| 249 |
+
print(f"{status_icon} top1={'Y' if found_top1 else 'N'} top3={'Y' if found_top3 else 'N'} ({elapsed_ms}ms)")
|
| 250 |
+
else:
|
| 251 |
+
scores = {
|
| 252 |
+
"top1_accuracy": 0.0,
|
| 253 |
+
"top3_accuracy": 0.0,
|
| 254 |
+
"mentioned_accuracy": 0.0,
|
| 255 |
+
"parse_success": 0.0,
|
| 256 |
+
}
|
| 257 |
+
details = {"correct_answer": correct_answer, "error": error}
|
| 258 |
+
print(f"β FAILED: {error[:80] if error else 'unknown'}")
|
| 259 |
+
|
| 260 |
+
results.append(ValidationResult(
|
| 261 |
+
case_id=case.case_id,
|
| 262 |
+
source_dataset="medqa",
|
| 263 |
+
success=report is not None,
|
| 264 |
+
scores=scores,
|
| 265 |
+
pipeline_time_ms=elapsed_ms,
|
| 266 |
+
step_results=step_results,
|
| 267 |
+
report_summary=report.patient_summary[:200] if report else None,
|
| 268 |
+
error=error,
|
| 269 |
+
details=details,
|
| 270 |
+
))
|
| 271 |
+
|
| 272 |
+
# Rate limit
|
| 273 |
+
if i < len(cases) - 1:
|
| 274 |
+
await asyncio.sleep(delay_between_cases)
|
| 275 |
+
|
| 276 |
+
# Aggregate
|
| 277 |
+
total = len(results)
|
| 278 |
+
successful = sum(1 for r in results if r.success)
|
| 279 |
+
|
| 280 |
+
# Average each metric across successful cases only
|
| 281 |
+
metric_names = ["top1_accuracy", "top3_accuracy", "mentioned_accuracy", "parse_success"]
|
| 282 |
+
metrics = {}
|
| 283 |
+
for m in metric_names:
|
| 284 |
+
values = [r.scores.get(m, 0.0) for r in results]
|
| 285 |
+
metrics[m] = sum(values) / len(values) if values else 0.0
|
| 286 |
+
|
| 287 |
+
# Average pipeline time
|
| 288 |
+
times = [r.pipeline_time_ms for r in results if r.success]
|
| 289 |
+
metrics["avg_pipeline_time_ms"] = sum(times) / len(times) if times else 0
|
| 290 |
+
|
| 291 |
+
summary = ValidationSummary(
|
| 292 |
+
dataset="medqa",
|
| 293 |
+
total_cases=total,
|
| 294 |
+
successful_cases=successful,
|
| 295 |
+
failed_cases=total - successful,
|
| 296 |
+
metrics=metrics,
|
| 297 |
+
per_case=results,
|
| 298 |
+
run_duration_sec=time.time() - start_time,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
return summary
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 305 |
+
# Standalone runner
|
| 306 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 307 |
+
|
| 308 |
+
async def main():
|
| 309 |
+
"""Run MedQA validation standalone."""
|
| 310 |
+
import argparse
|
| 311 |
+
|
| 312 |
+
parser = argparse.ArgumentParser(description="MedQA Validation")
|
| 313 |
+
parser.add_argument("--max-cases", type=int, default=10, help="Number of cases to evaluate")
|
| 314 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 315 |
+
parser.add_argument("--include-drugs", action="store_true", help="Include drug interaction check")
|
| 316 |
+
parser.add_argument("--delay", type=float, default=2.0, help="Delay between cases (seconds)")
|
| 317 |
+
args = parser.parse_args()
|
| 318 |
+
|
| 319 |
+
print("MedQA Validation Harness")
|
| 320 |
+
print("=" * 40)
|
| 321 |
+
|
| 322 |
+
cases = await fetch_medqa(max_cases=args.max_cases, seed=args.seed)
|
| 323 |
+
summary = await validate_medqa(
|
| 324 |
+
cases,
|
| 325 |
+
include_drug_check=args.include_drugs,
|
| 326 |
+
delay_between_cases=args.delay,
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
print_summary(summary)
|
| 330 |
+
path = save_results(summary)
|
| 331 |
+
print(f"Results saved to: {path}")
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
if __name__ == "__main__":
|
| 335 |
+
asyncio.run(main())
|
|
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MTSamples dataset fetcher and validation harness.
|
| 3 |
+
|
| 4 |
+
Downloads medical transcription samples and evaluates the CDS pipeline's
|
| 5 |
+
ability to parse diverse clinical note formats and reason across specialties.
|
| 6 |
+
|
| 7 |
+
Source: https://mtsamples.com (via GitHub mirrors)
|
| 8 |
+
Format: CSV with columns: description, medical_specialty, sample_name, transcription, keywords
|
| 9 |
+
|
| 10 |
+
Metrics:
|
| 11 |
+
- parse_success_rate: Pipeline completed without crashing
|
| 12 |
+
- field_completeness: How many structured fields were extracted
|
| 13 |
+
- specialty_alignment: System reasoning aligns with correct specialty
|
| 14 |
+
- has_differential: Report includes at least one diagnosis
|
| 15 |
+
- has_recommendations: Report includes next steps
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import asyncio
|
| 20 |
+
import csv
|
| 21 |
+
import io
|
| 22 |
+
import json
|
| 23 |
+
import random
|
| 24 |
+
import re
|
| 25 |
+
import time
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import List, Optional
|
| 28 |
+
|
| 29 |
+
import httpx
|
| 30 |
+
|
| 31 |
+
from validation.base import (
|
| 32 |
+
DATA_DIR,
|
| 33 |
+
ValidationCase,
|
| 34 |
+
ValidationResult,
|
| 35 |
+
ValidationSummary,
|
| 36 |
+
ensure_data_dir,
|
| 37 |
+
fuzzy_match,
|
| 38 |
+
normalize_text,
|
| 39 |
+
print_summary,
|
| 40 |
+
run_cds_pipeline,
|
| 41 |
+
save_results,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
# Data fetching
|
| 47 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 48 |
+
|
| 49 |
+
MTSAMPLES_URL = "https://raw.githubusercontent.com/socd06/medical-nlp/master/data/mtsamples.csv"
|
| 50 |
+
MTSAMPLES_FALLBACK_URL = "https://raw.githubusercontent.com/Abonia1/Clinical-NLP-on-MTSamples/master/mtsamples.csv"
|
| 51 |
+
|
| 52 |
+
# Specialties most relevant to CDS
|
| 53 |
+
RELEVANT_SPECIALTIES = {
|
| 54 |
+
"Cardiovascular / Pulmonary",
|
| 55 |
+
"Gastroenterology",
|
| 56 |
+
"General Medicine",
|
| 57 |
+
"Neurology",
|
| 58 |
+
"Orthopedic",
|
| 59 |
+
"Urology",
|
| 60 |
+
"Nephrology",
|
| 61 |
+
"Endocrinology",
|
| 62 |
+
"Hematology - Oncology",
|
| 63 |
+
"Obstetrics / Gynecology",
|
| 64 |
+
"Emergency Room Reports",
|
| 65 |
+
"Consult - History and Phy.",
|
| 66 |
+
"Discharge Summary",
|
| 67 |
+
"SOAP / Chart / Progress Notes",
|
| 68 |
+
"Internal Medicine",
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
async def fetch_mtsamples(
|
| 73 |
+
max_cases: int = 30,
|
| 74 |
+
seed: int = 42,
|
| 75 |
+
specialties: Optional[set] = None,
|
| 76 |
+
min_length: int = 200,
|
| 77 |
+
) -> List[ValidationCase]:
|
| 78 |
+
"""
|
| 79 |
+
Download MTSamples and convert to ValidationCase objects.
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
max_cases: Maximum number of cases to sample
|
| 83 |
+
seed: Random seed for reproducible sampling
|
| 84 |
+
specialties: Filter to these specialties (None = use RELEVANT_SPECIALTIES)
|
| 85 |
+
min_length: Minimum transcription length to include
|
| 86 |
+
"""
|
| 87 |
+
ensure_data_dir()
|
| 88 |
+
cache_path = DATA_DIR / "mtsamples.csv"
|
| 89 |
+
|
| 90 |
+
if cache_path.exists():
|
| 91 |
+
print(f" Loading MTSamples from cache: {cache_path}")
|
| 92 |
+
raw_text = cache_path.read_text(encoding="utf-8")
|
| 93 |
+
else:
|
| 94 |
+
print(f" Downloading MTSamples...")
|
| 95 |
+
raw_text = await _download_mtsamples(cache_path)
|
| 96 |
+
|
| 97 |
+
if not raw_text:
|
| 98 |
+
raise RuntimeError("Failed to fetch MTSamples data.")
|
| 99 |
+
|
| 100 |
+
# Parse CSV
|
| 101 |
+
reader = csv.DictReader(io.StringIO(raw_text))
|
| 102 |
+
rows = list(reader)
|
| 103 |
+
|
| 104 |
+
# Filter
|
| 105 |
+
target_specialties = specialties or RELEVANT_SPECIALTIES
|
| 106 |
+
filtered = []
|
| 107 |
+
for row in rows:
|
| 108 |
+
specialty = row.get("medical_specialty", "").strip()
|
| 109 |
+
transcription = row.get("transcription", "").strip()
|
| 110 |
+
if not transcription or len(transcription) < min_length:
|
| 111 |
+
continue
|
| 112 |
+
if specialty in target_specialties:
|
| 113 |
+
filtered.append(row)
|
| 114 |
+
|
| 115 |
+
# Sample
|
| 116 |
+
random.seed(seed)
|
| 117 |
+
if len(filtered) > max_cases:
|
| 118 |
+
# Stratified sample: try to get cases from diverse specialties
|
| 119 |
+
by_specialty = {}
|
| 120 |
+
for row in filtered:
|
| 121 |
+
sp = row.get("medical_specialty", "Other")
|
| 122 |
+
by_specialty.setdefault(sp, []).append(row)
|
| 123 |
+
|
| 124 |
+
sampled = []
|
| 125 |
+
per_specialty = max(1, max_cases // len(by_specialty))
|
| 126 |
+
for sp, sp_rows in by_specialty.items():
|
| 127 |
+
sampled.extend(random.sample(sp_rows, min(per_specialty, len(sp_rows))))
|
| 128 |
+
|
| 129 |
+
# Fill remaining slots randomly
|
| 130 |
+
remaining = [r for r in filtered if r not in sampled]
|
| 131 |
+
if len(sampled) < max_cases and remaining:
|
| 132 |
+
sampled.extend(random.sample(remaining, min(max_cases - len(sampled), len(remaining))))
|
| 133 |
+
|
| 134 |
+
filtered = sampled[:max_cases]
|
| 135 |
+
|
| 136 |
+
# Convert to ValidationCase
|
| 137 |
+
cases = []
|
| 138 |
+
for i, row in enumerate(filtered):
|
| 139 |
+
transcription = row.get("transcription", "").strip()
|
| 140 |
+
specialty = row.get("medical_specialty", "Unknown").strip()
|
| 141 |
+
description = row.get("description", "").strip()
|
| 142 |
+
keywords = row.get("keywords", "").strip()
|
| 143 |
+
|
| 144 |
+
cases.append(ValidationCase(
|
| 145 |
+
case_id=f"mts_{i:04d}",
|
| 146 |
+
source_dataset="mtsamples",
|
| 147 |
+
input_text=transcription,
|
| 148 |
+
ground_truth={
|
| 149 |
+
"specialty": specialty,
|
| 150 |
+
"description": description,
|
| 151 |
+
"keywords": keywords,
|
| 152 |
+
},
|
| 153 |
+
metadata={
|
| 154 |
+
"sample_name": row.get("sample_name", ""),
|
| 155 |
+
"text_length": len(transcription),
|
| 156 |
+
},
|
| 157 |
+
))
|
| 158 |
+
|
| 159 |
+
print(f" Loaded {len(cases)} MTSamples cases across {len(set(c.ground_truth['specialty'] for c in cases))} specialties")
|
| 160 |
+
return cases
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
async def _download_mtsamples(cache_path: Path) -> str:
|
| 164 |
+
"""Download MTSamples CSV."""
|
| 165 |
+
async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
|
| 166 |
+
for url in [MTSAMPLES_URL, MTSAMPLES_FALLBACK_URL]:
|
| 167 |
+
try:
|
| 168 |
+
r = await client.get(url)
|
| 169 |
+
r.raise_for_status()
|
| 170 |
+
cache_path.write_text(r.text, encoding="utf-8")
|
| 171 |
+
print(f" Cached MTSamples ({len(r.text)} bytes) to {cache_path}")
|
| 172 |
+
return r.text
|
| 173 |
+
except Exception as e:
|
| 174 |
+
print(f" Warning: Failed to download from {url}: {e}")
|
| 175 |
+
continue
|
| 176 |
+
return ""
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 180 |
+
# Scoring helpers
|
| 181 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 182 |
+
|
| 183 |
+
SPECIALTY_KEYWORDS = {
|
| 184 |
+
"Cardiovascular / Pulmonary": ["cardiac", "heart", "coronary", "pulmonary", "lung", "chest", "hypertension", "arrhythmia"],
|
| 185 |
+
"Gastroenterology": ["gastro", "liver", "hepat", "colon", "bowel", "gi ", "abdominal", "pancrea"],
|
| 186 |
+
"General Medicine": ["general", "medicine", "primary", "routine"],
|
| 187 |
+
"Neurology": ["neuro", "brain", "seizure", "stroke", "headache", "neuropathy", "ms "],
|
| 188 |
+
"Orthopedic": ["ortho", "fracture", "bone", "joint", "knee", "hip", "shoulder", "spine"],
|
| 189 |
+
"Urology": ["urol", "kidney", "bladder", "prostate", "renal", "urinary"],
|
| 190 |
+
"Nephrology": ["renal", "kidney", "dialysis", "nephr", "creatinine"],
|
| 191 |
+
"Endocrinology": ["diabet", "thyroid", "endocrin", "insulin", "glucose", "adrenal"],
|
| 192 |
+
"Hematology - Oncology": ["cancer", "tumor", "leukemia", "lymphoma", "anemia", "oncol"],
|
| 193 |
+
"Obstetrics / Gynecology": ["pregnan", "obstet", "gynecol", "uterus", "ovarian", "menstrual"],
|
| 194 |
+
"Emergency Room Reports": ["emergency", "trauma", "acute", "er ", "ed "],
|
| 195 |
+
"Internal Medicine": ["internal", "medicine"],
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def check_specialty_alignment(report_text: str, target_specialty: str) -> bool:
|
| 200 |
+
"""Check if the report's content aligns with the expected specialty."""
|
| 201 |
+
keywords = SPECIALTY_KEYWORDS.get(target_specialty, [])
|
| 202 |
+
if not keywords:
|
| 203 |
+
return True # Can't check, assume aligned
|
| 204 |
+
|
| 205 |
+
report_lower = report_text.lower()
|
| 206 |
+
matches = sum(1 for kw in keywords if kw in report_lower)
|
| 207 |
+
return matches >= 1 # At least one specialty keyword present
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def score_field_completeness(state) -> float:
|
| 211 |
+
"""Score how many structured fields were successfully extracted from parsing."""
|
| 212 |
+
if not state or not state.patient_profile:
|
| 213 |
+
return 0.0
|
| 214 |
+
|
| 215 |
+
profile = state.patient_profile
|
| 216 |
+
fields = [
|
| 217 |
+
profile.age is not None,
|
| 218 |
+
profile.gender.value != "unknown",
|
| 219 |
+
bool(profile.chief_complaint),
|
| 220 |
+
bool(profile.history_of_present_illness),
|
| 221 |
+
len(profile.past_medical_history) > 0,
|
| 222 |
+
len(profile.current_medications) > 0,
|
| 223 |
+
len(profile.allergies) > 0,
|
| 224 |
+
len(profile.lab_results) > 0,
|
| 225 |
+
profile.vital_signs is not None,
|
| 226 |
+
bool(profile.social_history),
|
| 227 |
+
bool(profile.family_history),
|
| 228 |
+
]
|
| 229 |
+
return sum(fields) / len(fields)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 233 |
+
# Validation harness
|
| 234 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 235 |
+
|
| 236 |
+
async def validate_mtsamples(
|
| 237 |
+
cases: List[ValidationCase],
|
| 238 |
+
include_drug_check: bool = True,
|
| 239 |
+
include_guidelines: bool = True,
|
| 240 |
+
delay_between_cases: float = 2.0,
|
| 241 |
+
) -> ValidationSummary:
|
| 242 |
+
"""
|
| 243 |
+
Run MTSamples cases through the CDS pipeline and score results.
|
| 244 |
+
"""
|
| 245 |
+
results: List[ValidationResult] = []
|
| 246 |
+
start_time = time.time()
|
| 247 |
+
|
| 248 |
+
for i, case in enumerate(cases):
|
| 249 |
+
specialty = case.ground_truth.get("specialty", "?")
|
| 250 |
+
print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty}): ", end="", flush=True)
|
| 251 |
+
|
| 252 |
+
case_start = time.monotonic()
|
| 253 |
+
|
| 254 |
+
state, report, error = await run_cds_pipeline(
|
| 255 |
+
patient_text=case.input_text,
|
| 256 |
+
include_drug_check=include_drug_check,
|
| 257 |
+
include_guidelines=include_guidelines,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
elapsed_ms = int((time.monotonic() - case_start) * 1000)
|
| 261 |
+
|
| 262 |
+
# Step results
|
| 263 |
+
step_results = {}
|
| 264 |
+
if state:
|
| 265 |
+
step_results = {s.step_id: s.status.value for s in state.steps}
|
| 266 |
+
|
| 267 |
+
# Score
|
| 268 |
+
scores = {}
|
| 269 |
+
details = {}
|
| 270 |
+
|
| 271 |
+
# Parse success
|
| 272 |
+
scores["parse_success"] = 1.0 if (state and state.patient_profile) else 0.0
|
| 273 |
+
|
| 274 |
+
# Field completeness
|
| 275 |
+
scores["field_completeness"] = score_field_completeness(state)
|
| 276 |
+
|
| 277 |
+
if report:
|
| 278 |
+
# Has differential
|
| 279 |
+
scores["has_differential"] = 1.0 if len(report.differential_diagnosis) > 0 else 0.0
|
| 280 |
+
|
| 281 |
+
# Has recommendations
|
| 282 |
+
scores["has_recommendations"] = 1.0 if len(report.suggested_next_steps) > 0 else 0.0
|
| 283 |
+
|
| 284 |
+
# Has guideline recommendations
|
| 285 |
+
scores["has_guidelines"] = 1.0 if len(report.guideline_recommendations) > 0 else 0.0
|
| 286 |
+
|
| 287 |
+
# Specialty alignment
|
| 288 |
+
full_report_text = " ".join([
|
| 289 |
+
report.patient_summary or "",
|
| 290 |
+
" ".join(d.diagnosis for d in report.differential_diagnosis),
|
| 291 |
+
" ".join(report.guideline_recommendations),
|
| 292 |
+
" ".join(a.action for a in report.suggested_next_steps),
|
| 293 |
+
])
|
| 294 |
+
scores["specialty_alignment"] = 1.0 if check_specialty_alignment(
|
| 295 |
+
full_report_text, specialty
|
| 296 |
+
) else 0.0
|
| 297 |
+
|
| 298 |
+
# Conflict detection worked (if applicable)
|
| 299 |
+
if state and state.conflict_detection:
|
| 300 |
+
scores["conflict_detection_ran"] = 1.0
|
| 301 |
+
else:
|
| 302 |
+
scores["conflict_detection_ran"] = 0.0
|
| 303 |
+
|
| 304 |
+
details = {
|
| 305 |
+
"specialty": specialty,
|
| 306 |
+
"num_diagnoses": len(report.differential_diagnosis),
|
| 307 |
+
"num_recommendations": len(report.suggested_next_steps),
|
| 308 |
+
"field_completeness": scores["field_completeness"],
|
| 309 |
+
"num_conflicts": len(report.conflicts) if report.conflicts else 0,
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
print(f"β fields={scores['field_completeness']:.0%} dx={len(report.differential_diagnosis)} ({elapsed_ms}ms)")
|
| 313 |
+
else:
|
| 314 |
+
scores.update({
|
| 315 |
+
"has_differential": 0.0,
|
| 316 |
+
"has_recommendations": 0.0,
|
| 317 |
+
"has_guidelines": 0.0,
|
| 318 |
+
"specialty_alignment": 0.0,
|
| 319 |
+
"conflict_detection_ran": 0.0,
|
| 320 |
+
})
|
| 321 |
+
details = {"specialty": specialty, "error": error}
|
| 322 |
+
print(f"β FAILED: {error[:80] if error else 'unknown'}")
|
| 323 |
+
|
| 324 |
+
results.append(ValidationResult(
|
| 325 |
+
case_id=case.case_id,
|
| 326 |
+
source_dataset="mtsamples",
|
| 327 |
+
success=report is not None,
|
| 328 |
+
scores=scores,
|
| 329 |
+
pipeline_time_ms=elapsed_ms,
|
| 330 |
+
step_results=step_results,
|
| 331 |
+
report_summary=report.patient_summary[:200] if report else None,
|
| 332 |
+
error=error,
|
| 333 |
+
details=details,
|
| 334 |
+
))
|
| 335 |
+
|
| 336 |
+
if i < len(cases) - 1:
|
| 337 |
+
await asyncio.sleep(delay_between_cases)
|
| 338 |
+
|
| 339 |
+
# Aggregate
|
| 340 |
+
total = len(results)
|
| 341 |
+
successful = sum(1 for r in results if r.success)
|
| 342 |
+
|
| 343 |
+
metric_names = [
|
| 344 |
+
"parse_success", "field_completeness", "has_differential",
|
| 345 |
+
"has_recommendations", "has_guidelines", "specialty_alignment",
|
| 346 |
+
"conflict_detection_ran",
|
| 347 |
+
]
|
| 348 |
+
metrics = {}
|
| 349 |
+
for m in metric_names:
|
| 350 |
+
values = [r.scores.get(m, 0.0) for r in results]
|
| 351 |
+
metrics[m] = sum(values) / len(values) if values else 0.0
|
| 352 |
+
|
| 353 |
+
times = [r.pipeline_time_ms for r in results if r.success]
|
| 354 |
+
metrics["avg_pipeline_time_ms"] = sum(times) / len(times) if times else 0
|
| 355 |
+
|
| 356 |
+
summary = ValidationSummary(
|
| 357 |
+
dataset="mtsamples",
|
| 358 |
+
total_cases=total,
|
| 359 |
+
successful_cases=successful,
|
| 360 |
+
failed_cases=total - successful,
|
| 361 |
+
metrics=metrics,
|
| 362 |
+
per_case=results,
|
| 363 |
+
run_duration_sec=time.time() - start_time,
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
return summary
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 370 |
+
# Standalone runner
|
| 371 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 372 |
+
|
| 373 |
+
async def main():
|
| 374 |
+
"""Run MTSamples validation standalone."""
|
| 375 |
+
import argparse
|
| 376 |
+
|
| 377 |
+
parser = argparse.ArgumentParser(description="MTSamples Validation")
|
| 378 |
+
parser.add_argument("--max-cases", type=int, default=10, help="Number of cases to evaluate")
|
| 379 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 380 |
+
parser.add_argument("--no-drugs", action="store_true", help="Skip drug interaction check")
|
| 381 |
+
parser.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
|
| 382 |
+
parser.add_argument("--delay", type=float, default=2.0, help="Delay between cases (seconds)")
|
| 383 |
+
args = parser.parse_args()
|
| 384 |
+
|
| 385 |
+
print("MTSamples Validation Harness")
|
| 386 |
+
print("=" * 40)
|
| 387 |
+
|
| 388 |
+
cases = await fetch_mtsamples(max_cases=args.max_cases, seed=args.seed)
|
| 389 |
+
summary = await validate_mtsamples(
|
| 390 |
+
cases,
|
| 391 |
+
include_drug_check=not args.no_drugs,
|
| 392 |
+
include_guidelines=not args.no_guidelines,
|
| 393 |
+
delay_between_cases=args.delay,
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
print_summary(summary)
|
| 397 |
+
path = save_results(summary)
|
| 398 |
+
print(f"Results saved to: {path}")
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
if __name__ == "__main__":
|
| 402 |
+
asyncio.run(main())
|
|
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PMC Case Reports fetcher and validation harness.
|
| 3 |
+
|
| 4 |
+
Fetches published clinical case reports from PubMed Central and evaluates
|
| 5 |
+
the CDS pipeline's diagnostic accuracy against gold-standard diagnoses.
|
| 6 |
+
|
| 7 |
+
Source: NCBI PubMed / PubMed Central (E-utilities API)
|
| 8 |
+
Format: XML abstracts with case presentations and final diagnoses
|
| 9 |
+
|
| 10 |
+
Metrics:
|
| 11 |
+
- diagnostic_accuracy: Correct diagnosis appears in differential
|
| 12 |
+
- top3_accuracy: Correct diagnosis in top 3
|
| 13 |
+
- parse_success_rate: Pipeline completed without crashing
|
| 14 |
+
- has_recommendations: Report includes actionable next steps
|
| 15 |
+
"""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import asyncio
|
| 19 |
+
import json
|
| 20 |
+
import random
|
| 21 |
+
import re
|
| 22 |
+
import time
|
| 23 |
+
import xml.etree.ElementTree as ET
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
from typing import List, Optional, Tuple
|
| 26 |
+
|
| 27 |
+
import httpx
|
| 28 |
+
|
| 29 |
+
from validation.base import (
|
| 30 |
+
DATA_DIR,
|
| 31 |
+
ValidationCase,
|
| 32 |
+
ValidationResult,
|
| 33 |
+
ValidationSummary,
|
| 34 |
+
diagnosis_in_differential,
|
| 35 |
+
ensure_data_dir,
|
| 36 |
+
fuzzy_match,
|
| 37 |
+
normalize_text,
|
| 38 |
+
print_summary,
|
| 39 |
+
run_cds_pipeline,
|
| 40 |
+
save_results,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
+
# NCBI E-utilities configuration
|
| 46 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 47 |
+
|
| 48 |
+
EUTILS_BASE = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
|
| 49 |
+
ESEARCH_URL = f"{EUTILS_BASE}/esearch.fcgi"
|
| 50 |
+
EFETCH_URL = f"{EUTILS_BASE}/efetch.fcgi"
|
| 51 |
+
|
| 52 |
+
# Curated search queries for case reports with clear diagnoses
|
| 53 |
+
# Each tuple: (search_query, expected_specialty)
|
| 54 |
+
CASE_REPORT_QUERIES = [
|
| 55 |
+
('"case report"[Title] AND "myocardial infarction"[Title] AND diagnosis', "Cardiology"),
|
| 56 |
+
('"case report"[Title] AND "pneumonia"[Title] AND diagnosis', "Pulmonology"),
|
| 57 |
+
('"case report"[Title] AND "diabetic ketoacidosis"[Title]', "Endocrinology"),
|
| 58 |
+
('"case report"[Title] AND "stroke"[Title] AND diagnosis', "Neurology"),
|
| 59 |
+
('"case report"[Title] AND "appendicitis"[Title] AND diagnosis', "Surgery"),
|
| 60 |
+
('"case report"[Title] AND "pulmonary embolism"[Title]', "Pulmonology"),
|
| 61 |
+
('"case report"[Title] AND "sepsis"[Title] AND management', "Critical Care"),
|
| 62 |
+
('"case report"[Title] AND "heart failure"[Title] AND management', "Cardiology"),
|
| 63 |
+
('"case report"[Title] AND "pancreatitis"[Title] AND diagnosis', "Gastroenterology"),
|
| 64 |
+
('"case report"[Title] AND "meningitis"[Title] AND diagnosis', "Neurology/ID"),
|
| 65 |
+
('"case report"[Title] AND "urinary tract infection"[Title]', "Urology/ID"),
|
| 66 |
+
('"case report"[Title] AND "thyroid"[Title] AND "nodule"', "Endocrinology"),
|
| 67 |
+
('"case report"[Title] AND "deep vein thrombosis"[Title]', "Hematology"),
|
| 68 |
+
('"case report"[Title] AND "anaphylaxis"[Title]', "Allergy/EM"),
|
| 69 |
+
('"case report"[Title] AND "renal failure"[Title] AND acute', "Nephrology"),
|
| 70 |
+
('"case report"[Title] AND "liver cirrhosis"[Title]', "Hepatology"),
|
| 71 |
+
('"case report"[Title] AND "asthma"[Title] AND exacerbation', "Pulmonology"),
|
| 72 |
+
('"case report"[Title] AND "seizure"[Title] AND diagnosis', "Neurology"),
|
| 73 |
+
('"case report"[Title] AND "hypoglycemia"[Title]', "Endocrinology"),
|
| 74 |
+
('"case report"[Title] AND "gastrointestinal bleeding"[Title]', "Gastroenterology"),
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
async def fetch_pmc_cases(
|
| 79 |
+
max_cases: int = 20,
|
| 80 |
+
seed: int = 42,
|
| 81 |
+
) -> List[ValidationCase]:
|
| 82 |
+
"""
|
| 83 |
+
Fetch case reports from PubMed and convert to ValidationCase objects.
|
| 84 |
+
|
| 85 |
+
Uses PubMed E-utilities to search for case reports with clear diagnoses,
|
| 86 |
+
then extracts the clinical presentation and diagnosis from abstracts.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
max_cases: Maximum number of cases to fetch
|
| 90 |
+
seed: Random seed for reproducible selection
|
| 91 |
+
"""
|
| 92 |
+
ensure_data_dir()
|
| 93 |
+
cache_path = DATA_DIR / "pmc_cases.json"
|
| 94 |
+
|
| 95 |
+
if cache_path.exists():
|
| 96 |
+
print(f" Loading PMC cases from cache: {cache_path}")
|
| 97 |
+
cached = json.loads(cache_path.read_text(encoding="utf-8"))
|
| 98 |
+
cases = [ValidationCase(**c) for c in cached]
|
| 99 |
+
if len(cases) >= max_cases:
|
| 100 |
+
random.seed(seed)
|
| 101 |
+
return random.sample(cases, min(max_cases, len(cases)))
|
| 102 |
+
# Fall through to fetch more if cache is insufficient
|
| 103 |
+
|
| 104 |
+
print(f" Fetching case reports from PubMed...")
|
| 105 |
+
cases = await _fetch_from_pubmed(max_cases, seed)
|
| 106 |
+
|
| 107 |
+
if cases:
|
| 108 |
+
# Cache
|
| 109 |
+
cached_data = [
|
| 110 |
+
{
|
| 111 |
+
"case_id": c.case_id,
|
| 112 |
+
"source_dataset": c.source_dataset,
|
| 113 |
+
"input_text": c.input_text,
|
| 114 |
+
"ground_truth": c.ground_truth,
|
| 115 |
+
"metadata": c.metadata,
|
| 116 |
+
}
|
| 117 |
+
for c in cases
|
| 118 |
+
]
|
| 119 |
+
cache_path.write_text(json.dumps(cached_data, indent=2), encoding="utf-8")
|
| 120 |
+
print(f" Cached {len(cases)} PMC cases to {cache_path}")
|
| 121 |
+
|
| 122 |
+
print(f" Loaded {len(cases)} PMC case reports")
|
| 123 |
+
return cases
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
async def _fetch_from_pubmed(max_cases: int, seed: int) -> List[ValidationCase]:
|
| 127 |
+
"""Fetch case reports via PubMed E-utilities."""
|
| 128 |
+
cases = []
|
| 129 |
+
random.seed(seed)
|
| 130 |
+
queries = random.sample(CASE_REPORT_QUERIES, min(max_cases, len(CASE_REPORT_QUERIES)))
|
| 131 |
+
|
| 132 |
+
async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client:
|
| 133 |
+
for query_text, specialty in queries:
|
| 134 |
+
if len(cases) >= max_cases:
|
| 135 |
+
break
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
# Step 1: Search for PMIDs
|
| 139 |
+
pmids = await _esearch(client, query_text, retmax=3)
|
| 140 |
+
if not pmids:
|
| 141 |
+
continue
|
| 142 |
+
|
| 143 |
+
# Step 2: Fetch abstracts
|
| 144 |
+
for pmid in pmids[:1]: # Take first result per query
|
| 145 |
+
abstract_data = await _efetch_abstract(client, pmid)
|
| 146 |
+
if not abstract_data:
|
| 147 |
+
continue
|
| 148 |
+
|
| 149 |
+
title, abstract = abstract_data
|
| 150 |
+
|
| 151 |
+
# Step 3: Extract case presentation and diagnosis
|
| 152 |
+
presentation, diagnosis = _extract_case_and_diagnosis(title, abstract, query_text)
|
| 153 |
+
if not presentation or not diagnosis:
|
| 154 |
+
continue
|
| 155 |
+
|
| 156 |
+
cases.append(ValidationCase(
|
| 157 |
+
case_id=f"pmc_{pmid}",
|
| 158 |
+
source_dataset="pmc",
|
| 159 |
+
input_text=presentation,
|
| 160 |
+
ground_truth={
|
| 161 |
+
"diagnosis": diagnosis,
|
| 162 |
+
"specialty": specialty,
|
| 163 |
+
"title": title,
|
| 164 |
+
},
|
| 165 |
+
metadata={
|
| 166 |
+
"pmid": pmid,
|
| 167 |
+
"full_abstract": abstract,
|
| 168 |
+
},
|
| 169 |
+
))
|
| 170 |
+
|
| 171 |
+
if len(cases) >= max_cases:
|
| 172 |
+
break
|
| 173 |
+
|
| 174 |
+
# NCBI rate limit: max 3 requests/second without API key
|
| 175 |
+
await asyncio.sleep(0.5)
|
| 176 |
+
|
| 177 |
+
except Exception as e:
|
| 178 |
+
print(f" Warning: Query failed '{query_text[:40]}...': {e}")
|
| 179 |
+
continue
|
| 180 |
+
|
| 181 |
+
return cases
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
async def _esearch(client: httpx.AsyncClient, query: str, retmax: int = 3) -> List[str]:
|
| 185 |
+
"""Search PubMed and return PMIDs."""
|
| 186 |
+
params = {
|
| 187 |
+
"db": "pubmed",
|
| 188 |
+
"term": query,
|
| 189 |
+
"retmax": retmax,
|
| 190 |
+
"retmode": "json",
|
| 191 |
+
"sort": "relevance",
|
| 192 |
+
}
|
| 193 |
+
r = await client.get(ESEARCH_URL, params=params)
|
| 194 |
+
r.raise_for_status()
|
| 195 |
+
data = r.json()
|
| 196 |
+
return data.get("esearchresult", {}).get("idlist", [])
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
async def _efetch_abstract(client: httpx.AsyncClient, pmid: str) -> Optional[Tuple[str, str]]:
|
| 200 |
+
"""Fetch the title and abstract for a PMID."""
|
| 201 |
+
params = {
|
| 202 |
+
"db": "pubmed",
|
| 203 |
+
"id": pmid,
|
| 204 |
+
"retmode": "xml",
|
| 205 |
+
}
|
| 206 |
+
r = await client.get(EFETCH_URL, params=params)
|
| 207 |
+
r.raise_for_status()
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
root = ET.fromstring(r.text)
|
| 211 |
+
|
| 212 |
+
# Extract title
|
| 213 |
+
title_el = root.find(".//ArticleTitle")
|
| 214 |
+
title = title_el.text if title_el is not None and title_el.text else ""
|
| 215 |
+
|
| 216 |
+
# Extract abstract
|
| 217 |
+
abstract_parts = []
|
| 218 |
+
for abs_text in root.findall(".//AbstractText"):
|
| 219 |
+
label = abs_text.get("Label", "")
|
| 220 |
+
text = abs_text.text or ""
|
| 221 |
+
# Collect tail text from sub-elements
|
| 222 |
+
full_text = (abs_text.text or "") + "".join(
|
| 223 |
+
(child.text or "") + (child.tail or "") for child in abs_text
|
| 224 |
+
)
|
| 225 |
+
if label:
|
| 226 |
+
abstract_parts.append(f"{label}: {full_text.strip()}")
|
| 227 |
+
else:
|
| 228 |
+
abstract_parts.append(full_text.strip())
|
| 229 |
+
|
| 230 |
+
abstract = " ".join(abstract_parts)
|
| 231 |
+
|
| 232 |
+
if len(abstract) < 100:
|
| 233 |
+
return None
|
| 234 |
+
|
| 235 |
+
return title, abstract
|
| 236 |
+
|
| 237 |
+
except ET.ParseError:
|
| 238 |
+
return None
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _extract_case_and_diagnosis(
|
| 242 |
+
title: str, abstract: str, search_query: str
|
| 243 |
+
) -> Tuple[Optional[str], Optional[str]]:
|
| 244 |
+
"""
|
| 245 |
+
Extract the clinical presentation and final diagnosis from a case report abstract.
|
| 246 |
+
|
| 247 |
+
Strategy:
|
| 248 |
+
1. Try structured abstract sections (CASE PRESENTATION, DIAGNOSIS, etc.)
|
| 249 |
+
2. Extract diagnosis from the title (common pattern: "A case of [diagnosis]")
|
| 250 |
+
3. Fall back to using the search condition as the expected diagnosis
|
| 251 |
+
"""
|
| 252 |
+
# Try to extract diagnosis from title
|
| 253 |
+
diagnosis = None
|
| 254 |
+
title_patterns = [
|
| 255 |
+
r"case (?:report )?of (.+?)(?:\.|:|$)",
|
| 256 |
+
r"presenting (?:as|with) (.+?)(?:\.|:|$)",
|
| 257 |
+
r"diagnosed (?:as|with) (.+?)(?:\.|:|$)",
|
| 258 |
+
r"rare case of (.+?)(?:\.|:|$)",
|
| 259 |
+
r"unusual (?:case|presentation) of (.+?)(?:\.|:|$)",
|
| 260 |
+
# Pattern: "Diagnosis Name: A Case Report"
|
| 261 |
+
r"^(.+?):\s*[Aa]\s*[Cc]ase\s*[Rr]eport",
|
| 262 |
+
# Pattern: "Diagnosis Name - Case Report"
|
| 263 |
+
r"^(.+?)\s*[-ββ]\s*[Cc]ase\s*[Rr]eport",
|
| 264 |
+
# Pattern: "Case of Diagnosis Name"
|
| 265 |
+
r"[Cc]ase\s+of\s+(.+?)(?:\.|:|,|$)",
|
| 266 |
+
]
|
| 267 |
+
for pattern in title_patterns:
|
| 268 |
+
match = re.search(pattern, title, re.IGNORECASE)
|
| 269 |
+
if match:
|
| 270 |
+
diagnosis = match.group(1).strip()
|
| 271 |
+
break
|
| 272 |
+
|
| 273 |
+
if not diagnosis:
|
| 274 |
+
# Extract from search query
|
| 275 |
+
# queries look like: '"case report"[Title] AND "myocardial infarction"[Title]'
|
| 276 |
+
# Find all quoted terms and pick the one that isn't "case report"
|
| 277 |
+
matches = re.findall(r'"([^"]+)"', search_query)
|
| 278 |
+
for m in matches:
|
| 279 |
+
if m.lower() != "case report":
|
| 280 |
+
diagnosis = m
|
| 281 |
+
break
|
| 282 |
+
|
| 283 |
+
if not diagnosis:
|
| 284 |
+
return None, None
|
| 285 |
+
|
| 286 |
+
# Clean diagnosis text
|
| 287 |
+
diagnosis = diagnosis.strip().rstrip('.')
|
| 288 |
+
|
| 289 |
+
# Extract clinical presentation
|
| 290 |
+
# For structured abstracts, look for specific sections
|
| 291 |
+
presentation_sections = ["CASE PRESENTATION", "CASE REPORT", "CASE", "CLINICAL PRESENTATION", "HISTORY"]
|
| 292 |
+
conclusion_sections = ["CONCLUSION", "DISCUSSION", "OUTCOME", "DIAGNOSIS", "RESULTS"]
|
| 293 |
+
|
| 294 |
+
# Try to split abstract into presentation vs conclusion
|
| 295 |
+
presentation = abstract
|
| 296 |
+
|
| 297 |
+
# Look for section boundaries in structured abstracts
|
| 298 |
+
for cs in conclusion_sections:
|
| 299 |
+
pattern = re.compile(rf'\b{cs}\b[:\s]', re.IGNORECASE)
|
| 300 |
+
match = pattern.search(abstract)
|
| 301 |
+
if match:
|
| 302 |
+
# Everything before the conclusion is the presentation
|
| 303 |
+
candidate = abstract[:match.start()].strip()
|
| 304 |
+
if len(candidate) > 100:
|
| 305 |
+
presentation = candidate
|
| 306 |
+
break
|
| 307 |
+
|
| 308 |
+
# Clean up
|
| 309 |
+
presentation = presentation.strip()
|
| 310 |
+
if len(presentation) < 50:
|
| 311 |
+
presentation = abstract # Use full abstract if extraction is too short
|
| 312 |
+
|
| 313 |
+
return presentation, diagnosis
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 317 |
+
# Validation harness
|
| 318 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 319 |
+
|
| 320 |
+
async def validate_pmc(
|
| 321 |
+
cases: List[ValidationCase],
|
| 322 |
+
include_drug_check: bool = True,
|
| 323 |
+
include_guidelines: bool = True,
|
| 324 |
+
delay_between_cases: float = 2.0,
|
| 325 |
+
) -> ValidationSummary:
|
| 326 |
+
"""
|
| 327 |
+
Run PMC case reports through the CDS pipeline and score results.
|
| 328 |
+
"""
|
| 329 |
+
results: List[ValidationResult] = []
|
| 330 |
+
start_time = time.time()
|
| 331 |
+
|
| 332 |
+
for i, case in enumerate(cases):
|
| 333 |
+
dx = case.ground_truth.get("diagnosis", "?")
|
| 334 |
+
specialty = case.ground_truth.get("specialty", "?")
|
| 335 |
+
print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty} β {dx[:40]}): ", end="", flush=True)
|
| 336 |
+
|
| 337 |
+
case_start = time.monotonic()
|
| 338 |
+
|
| 339 |
+
state, report, error = await run_cds_pipeline(
|
| 340 |
+
patient_text=case.input_text,
|
| 341 |
+
include_drug_check=include_drug_check,
|
| 342 |
+
include_guidelines=include_guidelines,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
elapsed_ms = int((time.monotonic() - case_start) * 1000)
|
| 346 |
+
|
| 347 |
+
step_results = {}
|
| 348 |
+
if state:
|
| 349 |
+
step_results = {s.step_id: s.status.value for s in state.steps}
|
| 350 |
+
|
| 351 |
+
scores = {}
|
| 352 |
+
details = {}
|
| 353 |
+
target_diagnosis = case.ground_truth["diagnosis"]
|
| 354 |
+
|
| 355 |
+
if report:
|
| 356 |
+
# Diagnostic accuracy (anywhere in differential)
|
| 357 |
+
found_any, rank_any = diagnosis_in_differential(target_diagnosis, report)
|
| 358 |
+
scores["diagnostic_accuracy"] = 1.0 if found_any else 0.0
|
| 359 |
+
|
| 360 |
+
# Top-3 accuracy
|
| 361 |
+
found_top3, rank3 = diagnosis_in_differential(target_diagnosis, report, top_n=3)
|
| 362 |
+
scores["top3_accuracy"] = 1.0 if found_top3 else 0.0
|
| 363 |
+
|
| 364 |
+
# Top-1 accuracy
|
| 365 |
+
found_top1, rank1 = diagnosis_in_differential(target_diagnosis, report, top_n=1)
|
| 366 |
+
scores["top1_accuracy"] = 1.0 if found_top1 else 0.0
|
| 367 |
+
|
| 368 |
+
# Parse success
|
| 369 |
+
scores["parse_success"] = 1.0
|
| 370 |
+
|
| 371 |
+
# Has recommendations
|
| 372 |
+
scores["has_recommendations"] = 1.0 if len(report.suggested_next_steps) > 0 else 0.0
|
| 373 |
+
|
| 374 |
+
details = {
|
| 375 |
+
"target_diagnosis": target_diagnosis,
|
| 376 |
+
"top_diagnosis": report.differential_diagnosis[0].diagnosis if report.differential_diagnosis else "NONE",
|
| 377 |
+
"num_diagnoses": len(report.differential_diagnosis),
|
| 378 |
+
"found_at_rank": rank_any if found_any else -1,
|
| 379 |
+
"all_diagnoses": [d.diagnosis for d in report.differential_diagnosis[:5]],
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
icon = "β" if found_any else "β"
|
| 383 |
+
top_dx = report.differential_diagnosis[0].diagnosis if report.differential_diagnosis else "NONE"
|
| 384 |
+
print(f"{icon} top1={'Y' if found_top1 else 'N'} diag={'Y' if found_any else 'N'} | top: {top_dx[:30]} ({elapsed_ms}ms)")
|
| 385 |
+
else:
|
| 386 |
+
scores = {
|
| 387 |
+
"diagnostic_accuracy": 0.0,
|
| 388 |
+
"top3_accuracy": 0.0,
|
| 389 |
+
"top1_accuracy": 0.0,
|
| 390 |
+
"parse_success": 0.0,
|
| 391 |
+
"has_recommendations": 0.0,
|
| 392 |
+
}
|
| 393 |
+
details = {"target_diagnosis": target_diagnosis, "error": error}
|
| 394 |
+
print(f"β FAILED: {error[:80] if error else 'unknown'}")
|
| 395 |
+
|
| 396 |
+
results.append(ValidationResult(
|
| 397 |
+
case_id=case.case_id,
|
| 398 |
+
source_dataset="pmc",
|
| 399 |
+
success=report is not None,
|
| 400 |
+
scores=scores,
|
| 401 |
+
pipeline_time_ms=elapsed_ms,
|
| 402 |
+
step_results=step_results,
|
| 403 |
+
report_summary=report.patient_summary[:200] if report else None,
|
| 404 |
+
error=error,
|
| 405 |
+
details=details,
|
| 406 |
+
))
|
| 407 |
+
|
| 408 |
+
if i < len(cases) - 1:
|
| 409 |
+
await asyncio.sleep(delay_between_cases)
|
| 410 |
+
|
| 411 |
+
# Aggregate
|
| 412 |
+
total = len(results)
|
| 413 |
+
successful = sum(1 for r in results if r.success)
|
| 414 |
+
|
| 415 |
+
metric_names = ["diagnostic_accuracy", "top3_accuracy", "top1_accuracy", "parse_success", "has_recommendations"]
|
| 416 |
+
metrics = {}
|
| 417 |
+
for m in metric_names:
|
| 418 |
+
values = [r.scores.get(m, 0.0) for r in results]
|
| 419 |
+
metrics[m] = sum(values) / len(values) if values else 0.0
|
| 420 |
+
|
| 421 |
+
times = [r.pipeline_time_ms for r in results if r.success]
|
| 422 |
+
metrics["avg_pipeline_time_ms"] = sum(times) / len(times) if times else 0
|
| 423 |
+
|
| 424 |
+
summary = ValidationSummary(
|
| 425 |
+
dataset="pmc",
|
| 426 |
+
total_cases=total,
|
| 427 |
+
successful_cases=successful,
|
| 428 |
+
failed_cases=total - successful,
|
| 429 |
+
metrics=metrics,
|
| 430 |
+
per_case=results,
|
| 431 |
+
run_duration_sec=time.time() - start_time,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
return summary
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 438 |
+
# Standalone runner
|
| 439 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 440 |
+
|
| 441 |
+
async def main():
|
| 442 |
+
"""Run PMC Case Reports validation standalone."""
|
| 443 |
+
import argparse
|
| 444 |
+
|
| 445 |
+
parser = argparse.ArgumentParser(description="PMC Case Reports Validation")
|
| 446 |
+
parser.add_argument("--max-cases", type=int, default=10, help="Number of cases to evaluate")
|
| 447 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 448 |
+
parser.add_argument("--no-drugs", action="store_true", help="Skip drug interaction check")
|
| 449 |
+
parser.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
|
| 450 |
+
parser.add_argument("--delay", type=float, default=2.0, help="Delay between cases (seconds)")
|
| 451 |
+
args = parser.parse_args()
|
| 452 |
+
|
| 453 |
+
print("PMC Case Reports Validation Harness")
|
| 454 |
+
print("=" * 40)
|
| 455 |
+
|
| 456 |
+
cases = await fetch_pmc_cases(max_cases=args.max_cases, seed=args.seed)
|
| 457 |
+
summary = await validate_pmc(
|
| 458 |
+
cases,
|
| 459 |
+
include_drug_check=not args.no_drugs,
|
| 460 |
+
include_guidelines=not args.no_guidelines,
|
| 461 |
+
delay_between_cases=args.delay,
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
print_summary(summary)
|
| 465 |
+
path = save_results(summary)
|
| 466 |
+
print(f"Results saved to: {path}")
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
if __name__ == "__main__":
|
| 470 |
+
asyncio.run(main())
|
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unified validation runner for the Clinical Decision Support Agent.
|
| 3 |
+
|
| 4 |
+
Runs all three dataset validations (MedQA, MTSamples, PMC Case Reports)
|
| 5 |
+
and produces a combined summary report.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
# From src/backend directory:
|
| 9 |
+
python -m validation.run_validation --all --max-cases 10
|
| 10 |
+
python -m validation.run_validation --medqa --max-cases 20
|
| 11 |
+
python -m validation.run_validation --mtsamples --max-cases 15
|
| 12 |
+
python -m validation.run_validation --pmc --max-cases 10
|
| 13 |
+
|
| 14 |
+
# Fetch data only (no pipeline execution):
|
| 15 |
+
python -m validation.run_validation --fetch-only
|
| 16 |
+
"""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import asyncio
|
| 20 |
+
import json
|
| 21 |
+
import sys
|
| 22 |
+
import time
|
| 23 |
+
from datetime import datetime, timezone
|
| 24 |
+
from pathlib import Path
|
| 25 |
+
|
| 26 |
+
# Ensure backend is importable
|
| 27 |
+
BACKEND_DIR = Path(__file__).resolve().parent.parent
|
| 28 |
+
if str(BACKEND_DIR) not in sys.path:
|
| 29 |
+
sys.path.insert(0, str(BACKEND_DIR))
|
| 30 |
+
|
| 31 |
+
from validation.base import (
|
| 32 |
+
ValidationSummary,
|
| 33 |
+
print_summary,
|
| 34 |
+
save_results,
|
| 35 |
+
)
|
| 36 |
+
from validation.harness_medqa import fetch_medqa, validate_medqa
|
| 37 |
+
from validation.harness_mtsamples import fetch_mtsamples, validate_mtsamples
|
| 38 |
+
from validation.harness_pmc import fetch_pmc_cases, validate_pmc
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
async def run_all_validations(
|
| 42 |
+
run_medqa: bool = True,
|
| 43 |
+
run_mtsamples: bool = True,
|
| 44 |
+
run_pmc: bool = True,
|
| 45 |
+
max_cases: int = 10,
|
| 46 |
+
seed: int = 42,
|
| 47 |
+
include_drug_check: bool = True,
|
| 48 |
+
include_guidelines: bool = True,
|
| 49 |
+
delay: float = 2.0,
|
| 50 |
+
fetch_only: bool = False,
|
| 51 |
+
) -> dict:
|
| 52 |
+
"""
|
| 53 |
+
Run validation against selected datasets.
|
| 54 |
+
|
| 55 |
+
Returns dict of {dataset_name: ValidationSummary}
|
| 56 |
+
"""
|
| 57 |
+
results = {}
|
| 58 |
+
start = time.time()
|
| 59 |
+
|
| 60 |
+
# ββ MedQA ββ
|
| 61 |
+
if run_medqa:
|
| 62 |
+
print("\n" + "=" * 60)
|
| 63 |
+
print(" DATASET 1: MedQA (USMLE-style diagnostic accuracy)")
|
| 64 |
+
print("=" * 60)
|
| 65 |
+
|
| 66 |
+
cases = await fetch_medqa(max_cases=max_cases, seed=seed)
|
| 67 |
+
|
| 68 |
+
if fetch_only:
|
| 69 |
+
print(f" Fetched {len(cases)} MedQA cases (fetch-only mode)")
|
| 70 |
+
else:
|
| 71 |
+
summary = await validate_medqa(
|
| 72 |
+
cases,
|
| 73 |
+
include_drug_check=include_drug_check,
|
| 74 |
+
include_guidelines=include_guidelines,
|
| 75 |
+
delay_between_cases=delay,
|
| 76 |
+
)
|
| 77 |
+
print_summary(summary)
|
| 78 |
+
save_results(summary)
|
| 79 |
+
results["medqa"] = summary
|
| 80 |
+
|
| 81 |
+
# ββ MTSamples ββ
|
| 82 |
+
if run_mtsamples:
|
| 83 |
+
print("\n" + "=" * 60)
|
| 84 |
+
print(" DATASET 2: MTSamples (clinical note parsing robustness)")
|
| 85 |
+
print("=" * 60)
|
| 86 |
+
|
| 87 |
+
cases = await fetch_mtsamples(max_cases=max_cases, seed=seed)
|
| 88 |
+
|
| 89 |
+
if fetch_only:
|
| 90 |
+
print(f" Fetched {len(cases)} MTSamples cases (fetch-only mode)")
|
| 91 |
+
else:
|
| 92 |
+
summary = await validate_mtsamples(
|
| 93 |
+
cases,
|
| 94 |
+
include_drug_check=include_drug_check,
|
| 95 |
+
include_guidelines=include_guidelines,
|
| 96 |
+
delay_between_cases=delay,
|
| 97 |
+
)
|
| 98 |
+
print_summary(summary)
|
| 99 |
+
save_results(summary)
|
| 100 |
+
results["mtsamples"] = summary
|
| 101 |
+
|
| 102 |
+
# ββ PMC Case Reports ββ
|
| 103 |
+
if run_pmc:
|
| 104 |
+
print("\n" + "=" * 60)
|
| 105 |
+
print(" DATASET 3: PMC Case Reports (real-world diagnostic accuracy)")
|
| 106 |
+
print("=" * 60)
|
| 107 |
+
|
| 108 |
+
cases = await fetch_pmc_cases(max_cases=max_cases, seed=seed)
|
| 109 |
+
|
| 110 |
+
if fetch_only:
|
| 111 |
+
print(f" Fetched {len(cases)} PMC cases (fetch-only mode)")
|
| 112 |
+
else:
|
| 113 |
+
summary = await validate_pmc(
|
| 114 |
+
cases,
|
| 115 |
+
include_drug_check=include_drug_check,
|
| 116 |
+
include_guidelines=include_guidelines,
|
| 117 |
+
delay_between_cases=delay,
|
| 118 |
+
)
|
| 119 |
+
print_summary(summary)
|
| 120 |
+
save_results(summary)
|
| 121 |
+
results["pmc"] = summary
|
| 122 |
+
|
| 123 |
+
# ββ Combined Summary ββ
|
| 124 |
+
total_duration = time.time() - start
|
| 125 |
+
|
| 126 |
+
if results and not fetch_only:
|
| 127 |
+
_print_combined_summary(results, total_duration)
|
| 128 |
+
_save_combined_report(results, total_duration)
|
| 129 |
+
|
| 130 |
+
return results
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def _print_combined_summary(results: dict, total_duration: float):
|
| 134 |
+
"""Print a combined summary across all datasets."""
|
| 135 |
+
print("\n" + "=" * 70)
|
| 136 |
+
print(" COMBINED VALIDATION REPORT")
|
| 137 |
+
print("=" * 70)
|
| 138 |
+
|
| 139 |
+
# Header
|
| 140 |
+
print(f"\n {'Dataset':<15} {'Cases':>6} {'Success':>8} {'Key Metric':>25} {'Value':>8}")
|
| 141 |
+
print(f" {'-'*15} {'-'*6} {'-'*8} {'-'*25} {'-'*8}")
|
| 142 |
+
|
| 143 |
+
for name, summary in results.items():
|
| 144 |
+
# Pick the most important metric for each dataset
|
| 145 |
+
if name == "medqa":
|
| 146 |
+
key_metric = "top3_accuracy"
|
| 147 |
+
elif name == "mtsamples":
|
| 148 |
+
key_metric = "parse_success"
|
| 149 |
+
elif name == "pmc":
|
| 150 |
+
key_metric = "diagnostic_accuracy"
|
| 151 |
+
else:
|
| 152 |
+
key_metric = list(summary.metrics.keys())[0] if summary.metrics else "N/A"
|
| 153 |
+
|
| 154 |
+
value = summary.metrics.get(key_metric, 0.0)
|
| 155 |
+
print(
|
| 156 |
+
f" {name:<15} {summary.total_cases:>6} "
|
| 157 |
+
f"{summary.successful_cases:>8} "
|
| 158 |
+
f"{key_metric:>25} {value:>7.1%}"
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# All metrics
|
| 162 |
+
print(f"\n {'β' * 66}")
|
| 163 |
+
for name, summary in results.items():
|
| 164 |
+
print(f"\n {name.upper()} metrics:")
|
| 165 |
+
for metric, value in sorted(summary.metrics.items()):
|
| 166 |
+
if "time" in metric and isinstance(value, (int, float)):
|
| 167 |
+
print(f" {metric:<35} {value:.0f}ms")
|
| 168 |
+
elif isinstance(value, float):
|
| 169 |
+
print(f" {metric:<35} {value:.1%}")
|
| 170 |
+
|
| 171 |
+
# Totals
|
| 172 |
+
total_cases = sum(s.total_cases for s in results.values())
|
| 173 |
+
total_success = sum(s.successful_cases for s in results.values())
|
| 174 |
+
print(f"\n Total cases: {total_cases}")
|
| 175 |
+
print(f" Total success: {total_success}")
|
| 176 |
+
print(f" Total duration: {total_duration:.1f}s ({total_duration/60:.1f}min)")
|
| 177 |
+
print(f" Timestamp: {datetime.now(timezone.utc).isoformat()}")
|
| 178 |
+
print("=" * 70)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _save_combined_report(results: dict, total_duration: float):
|
| 182 |
+
"""Save combined report to JSON."""
|
| 183 |
+
results_dir = Path(__file__).resolve().parent / "results"
|
| 184 |
+
results_dir.mkdir(parents=True, exist_ok=True)
|
| 185 |
+
|
| 186 |
+
ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
| 187 |
+
path = results_dir / f"combined_{ts}.json"
|
| 188 |
+
|
| 189 |
+
combined = {
|
| 190 |
+
"timestamp": datetime.now(timezone.utc).isoformat(),
|
| 191 |
+
"total_duration_sec": total_duration,
|
| 192 |
+
"datasets": {},
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
for name, summary in results.items():
|
| 196 |
+
combined["datasets"][name] = {
|
| 197 |
+
"total_cases": summary.total_cases,
|
| 198 |
+
"successful_cases": summary.successful_cases,
|
| 199 |
+
"failed_cases": summary.failed_cases,
|
| 200 |
+
"metrics": summary.metrics,
|
| 201 |
+
"run_duration_sec": summary.run_duration_sec,
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
path.write_text(json.dumps(combined, indent=2, default=str))
|
| 205 |
+
print(f"\n Combined report saved to: {path}")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def main():
|
| 209 |
+
"""CLI entry point."""
|
| 210 |
+
import argparse
|
| 211 |
+
|
| 212 |
+
parser = argparse.ArgumentParser(
|
| 213 |
+
description="CDS Agent Validation Suite",
|
| 214 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 215 |
+
epilog="""
|
| 216 |
+
Examples:
|
| 217 |
+
python -m validation.run_validation --all --max-cases 10
|
| 218 |
+
python -m validation.run_validation --medqa --max-cases 50
|
| 219 |
+
python -m validation.run_validation --fetch-only
|
| 220 |
+
python -m validation.run_validation --medqa --pmc --max-cases 20 --no-drugs
|
| 221 |
+
""",
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# Dataset selection
|
| 225 |
+
data_group = parser.add_argument_group("Datasets")
|
| 226 |
+
data_group.add_argument("--all", action="store_true", help="Run all three datasets")
|
| 227 |
+
data_group.add_argument("--medqa", action="store_true", help="Run MedQA validation")
|
| 228 |
+
data_group.add_argument("--mtsamples", action="store_true", help="Run MTSamples validation")
|
| 229 |
+
data_group.add_argument("--pmc", action="store_true", help="Run PMC Case Reports validation")
|
| 230 |
+
|
| 231 |
+
# Configuration
|
| 232 |
+
config_group = parser.add_argument_group("Configuration")
|
| 233 |
+
config_group.add_argument("--max-cases", type=int, default=10, help="Cases per dataset (default: 10)")
|
| 234 |
+
config_group.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
|
| 235 |
+
config_group.add_argument("--delay", type=float, default=2.0, help="Delay between cases in seconds (default: 2.0)")
|
| 236 |
+
config_group.add_argument("--no-drugs", action="store_true", help="Skip drug interaction checks")
|
| 237 |
+
config_group.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
|
| 238 |
+
config_group.add_argument("--fetch-only", action="store_true", help="Only download data, don't run pipeline")
|
| 239 |
+
|
| 240 |
+
args = parser.parse_args()
|
| 241 |
+
|
| 242 |
+
# Default to --all if nothing specified
|
| 243 |
+
if not any([args.all, args.medqa, args.mtsamples, args.pmc]):
|
| 244 |
+
args.all = True
|
| 245 |
+
|
| 246 |
+
run_medqa = args.all or args.medqa
|
| 247 |
+
run_mtsamples = args.all or args.mtsamples
|
| 248 |
+
run_pmc = args.all or args.pmc
|
| 249 |
+
|
| 250 |
+
print("ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
|
| 251 |
+
print("β Clinical Decision Support Agent β Validation Suite β")
|
| 252 |
+
print("ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ")
|
| 253 |
+
print(f"\n Datasets: {'MedQA ' if run_medqa else ''}{'MTSamples ' if run_mtsamples else ''}{'PMC ' if run_pmc else ''}")
|
| 254 |
+
print(f" Cases/dataset: {args.max_cases}")
|
| 255 |
+
print(f" Drug check: {'Yes' if not args.no_drugs else 'No'}")
|
| 256 |
+
print(f" Guidelines: {'Yes' if not args.no_guidelines else 'No'}")
|
| 257 |
+
print(f" Fetch only: {'Yes' if args.fetch_only else 'No'}")
|
| 258 |
+
|
| 259 |
+
asyncio.run(run_all_validations(
|
| 260 |
+
run_medqa=run_medqa,
|
| 261 |
+
run_mtsamples=run_mtsamples,
|
| 262 |
+
run_pmc=run_pmc,
|
| 263 |
+
max_cases=args.max_cases,
|
| 264 |
+
seed=args.seed,
|
| 265 |
+
include_drug_check=not args.no_drugs,
|
| 266 |
+
include_guidelines=not args.no_guidelines,
|
| 267 |
+
delay=args.delay,
|
| 268 |
+
fetch_only=args.fetch_only,
|
| 269 |
+
))
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
if __name__ == "__main__":
|
| 273 |
+
main()
|