File size: 10,273 Bytes
393ff7f 1f36481 393ff7f 1f36481 0fe4d92 1f36481 393ff7f 3d02eb2 393ff7f 3d02eb2 393ff7f 3d02eb2 393ff7f 3d02eb2 393ff7f 1f36481 393ff7f 3d02eb2 393ff7f 1f36481 393ff7f 3d02eb2 393ff7f 3d02eb2 393ff7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
"""
Unified validation runner for the Clinical Decision Support Agent.
Runs all three dataset validations (MedQA, MTSamples, PMC Case Reports)
and produces a combined summary report.
Usage:
# From src/backend directory:
python -m validation.run_validation --all --max-cases 10
python -m validation.run_validation --medqa --max-cases 20
python -m validation.run_validation --mtsamples --max-cases 15
python -m validation.run_validation --pmc --max-cases 10
# Fetch data only (no pipeline execution):
python -m validation.run_validation --fetch-only
"""
from __future__ import annotations

import asyncio
import json
import os
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

# Ensure backend is importable: this file lives in <backend>/validation/,
# so two .parent hops from here reach the backend root.
BACKEND_DIR = Path(__file__).resolve().parent.parent
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

# Load .env and export HF_TOKEN so huggingface_hub picks it up.
# NOTE: load_dotenv already populates os.environ; the explicit re-assignment
# below is a belt-and-braces re-export in case the token came from elsewhere.
from dotenv import load_dotenv

load_dotenv(BACKEND_DIR / ".env")
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token

# These imports are deliberately placed AFTER the sys.path fix above so the
# `validation` package resolves when run as a script from src/backend.
from validation.base import (
    ValidationSummary,
    print_summary,
    save_results,
)
from validation.harness_medqa import fetch_medqa, validate_medqa
from validation.harness_mtsamples import fetch_mtsamples, validate_mtsamples
from validation.harness_pmc import fetch_pmc_cases, validate_pmc
async def _run_dataset(
    key: str,
    title: str,
    label: str,
    fetch_fn,
    validate_fn,
    *,
    max_cases: int,
    seed: int,
    include_drug_check: bool,
    include_guidelines: bool,
    delay: float,
    fetch_only: bool,
    resume: bool,
    results: dict,
) -> None:
    """Fetch one dataset and, unless fetch_only, validate it and record the summary.

    Args:
        key: Key under which the summary is stored in ``results`` (e.g. "medqa").
        title: Banner line printed before the run (including leading space).
        label: Human-readable dataset name used in the fetch-only message.
        fetch_fn: Async callable ``(max_cases=..., seed=...) -> list`` of cases.
        validate_fn: Async validator matching the harness_* keyword signature.
        results: Mutated in place — the ValidationSummary is stored under ``key``.
    """
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)
    cases = await fetch_fn(max_cases=max_cases, seed=seed)
    if fetch_only:
        print(f" Fetched {len(cases)} {label} cases (fetch-only mode)")
        return
    summary = await validate_fn(
        cases,
        include_drug_check=include_drug_check,
        include_guidelines=include_guidelines,
        delay_between_cases=delay,
        resume=resume,
    )
    print_summary(summary)
    save_results(summary)
    results[key] = summary


async def run_all_validations(
    run_medqa: bool = True,
    run_mtsamples: bool = True,
    run_pmc: bool = True,
    max_cases: int = 10,
    seed: int = 42,
    include_drug_check: bool = True,
    include_guidelines: bool = True,
    delay: float = 2.0,
    fetch_only: bool = False,
    resume: bool = False,
) -> dict:
    """
    Run validation against selected datasets.
    Returns dict of {dataset_name: ValidationSummary}

    Args:
        run_medqa / run_mtsamples / run_pmc: Which datasets to include.
        max_cases: Cases to fetch per dataset.
        seed: Random seed forwarded to each fetcher.
        include_drug_check: Forwarded to each validator.
        include_guidelines: Forwarded to each validator.
        delay: Seconds to wait between cases (rate limiting).
        fetch_only: Only download data; skip pipeline execution.
        resume: Skip cases already completed in a previous checkpointed run.
    """
    results: dict = {}
    start = time.time()

    # Options shared by every dataset run.
    common = dict(
        max_cases=max_cases,
        seed=seed,
        include_drug_check=include_drug_check,
        include_guidelines=include_guidelines,
        delay=delay,
        fetch_only=fetch_only,
        resume=resume,
        results=results,
    )

    if run_medqa:
        await _run_dataset(
            "medqa",
            " DATASET 1: MedQA (USMLE-style diagnostic accuracy)",
            "MedQA",
            fetch_medqa,
            validate_medqa,
            **common,
        )
    if run_mtsamples:
        await _run_dataset(
            "mtsamples",
            " DATASET 2: MTSamples (clinical note parsing robustness)",
            "MTSamples",
            fetch_mtsamples,
            validate_mtsamples,
            **common,
        )
    if run_pmc:
        await _run_dataset(
            "pmc",
            " DATASET 3: PMC Case Reports (real-world diagnostic accuracy)",
            "PMC",
            fetch_pmc_cases,
            validate_pmc,
            **common,
        )

    # Combined report only makes sense when at least one pipeline actually ran.
    total_duration = time.time() - start
    if results and not fetch_only:
        _print_combined_summary(results, total_duration)
        _save_combined_report(results, total_duration)
    return results
def _print_combined_summary(results: dict, total_duration: float):
    """Print a cross-dataset summary table followed by per-dataset metrics."""
    banner = "=" * 70
    print("\n" + banner)
    print(" COMBINED VALIDATION REPORT")
    print(banner)

    # Table header.
    print(f"\n {'Dataset':<15} {'Cases':>6} {'Success':>8} {'Key Metric':>25} {'Value':>8}")
    print(f" {'-'*15} {'-'*6} {'-'*8} {'-'*25} {'-'*8}")

    # Headline metric per known dataset; unknown datasets fall back to
    # their first metric (or "N/A" when there are none).
    headline = {
        "medqa": "top3_accuracy",
        "mtsamples": "parse_success",
        "pmc": "diagnostic_accuracy",
    }
    for name, summary in results.items():
        key_metric = headline.get(name) or next(iter(summary.metrics), "N/A")
        value = summary.metrics.get(key_metric, 0.0)
        print(f" {name:<15} {summary.total_cases:>6} {summary.successful_cases:>8} {key_metric:>25} {value:>7.1%}")

    # Full metric dump for each dataset.
    print(f"\n {'-' * 66}")
    for name, summary in results.items():
        print(f"\n {name.upper()} metrics:")
        for metric in sorted(summary.metrics):
            value = summary.metrics[metric]
            if "time" in metric and isinstance(value, (int, float)):
                # Timing metrics are reported in milliseconds.
                print(f" {metric:<35} {value:.0f}ms")
            elif isinstance(value, float):
                # All other float metrics are rates/percentages.
                print(f" {metric:<35} {value:.1%}")

    # Grand totals across every dataset.
    print(f"\n Total cases: {sum(s.total_cases for s in results.values())}")
    print(f" Total success: {sum(s.successful_cases for s in results.values())}")
    print(f" Total duration: {total_duration:.1f}s ({total_duration/60:.1f}min)")
    print(f" Timestamp: {datetime.now(timezone.utc).isoformat()}")
    print("=" * 70)
def _save_combined_report(results: dict, total_duration: float):
    """Serialize the per-dataset summaries to a timestamped JSON report."""
    out_dir = Path(__file__).resolve().parent / "results"
    out_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    out_path = out_dir / f"combined_{stamp}.json"

    # Flatten each ValidationSummary into plain JSON-friendly fields.
    datasets = {
        name: {
            "total_cases": s.total_cases,
            "successful_cases": s.successful_cases,
            "failed_cases": s.failed_cases,
            "metrics": s.metrics,
            "run_duration_sec": s.run_duration_sec,
        }
        for name, s in results.items()
    }
    combined = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "total_duration_sec": total_duration,
        "datasets": datasets,
    }
    # default=str stringifies any non-serializable values (e.g. datetimes).
    out_path.write_text(json.dumps(combined, indent=2, default=str))
    print(f"\n Combined report saved to: {out_path}")
def main():
    """CLI entry point: parse arguments and launch the validation run."""
    import argparse

    parser = argparse.ArgumentParser(
        description="CDS Agent Validation Suite",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python -m validation.run_validation --all --max-cases 10
python -m validation.run_validation --medqa --max-cases 50
python -m validation.run_validation --fetch-only
python -m validation.run_validation --medqa --pmc --max-cases 20 --no-drugs
""",
    )

    # Which datasets to run.
    data_group = parser.add_argument_group("Datasets")
    data_group.add_argument("--all", action="store_true", help="Run all three datasets")
    data_group.add_argument("--medqa", action="store_true", help="Run MedQA validation")
    data_group.add_argument("--mtsamples", action="store_true", help="Run MTSamples validation")
    data_group.add_argument("--pmc", action="store_true", help="Run PMC Case Reports validation")

    # Run configuration.
    config_group = parser.add_argument_group("Configuration")
    config_group.add_argument("--max-cases", type=int, default=10, help="Cases per dataset (default: 10)")
    config_group.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    config_group.add_argument("--delay", type=float, default=2.0, help="Delay between cases in seconds (default: 2.0)")
    config_group.add_argument("--no-drugs", action="store_true", help="Skip drug interaction checks")
    config_group.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
    config_group.add_argument("--resume", action="store_true", help="Resume from checkpoint (skip already-completed cases)")
    config_group.add_argument("--fetch-only", action="store_true", help="Only download data, don't run pipeline")

    args = parser.parse_args()

    # With no dataset flags at all, behave as if --all was passed.
    if not any([args.all, args.medqa, args.mtsamples, args.pmc]):
        args.all = True
    run_medqa = args.all or args.medqa
    run_mtsamples = args.all or args.mtsamples
    run_pmc = args.all or args.pmc

    def _yn(flag: bool) -> str:
        """Render a boolean as Yes/No for the run banner."""
        return "Yes" if flag else "No"

    selected = "".join(
        label
        for enabled, label in (
            (run_medqa, "MedQA "),
            (run_mtsamples, "MTSamples "),
            (run_pmc, "PMC "),
        )
        if enabled
    )

    print("=" * 58)
    print(" Clinical Decision Support Agent - Validation Suite")
    print("=" * 58)
    print(f"\n Datasets: {selected}")
    print(f" Cases/dataset: {args.max_cases}")
    print(f" Drug check: {_yn(not args.no_drugs)}")
    print(f" Guidelines: {_yn(not args.no_guidelines)}")
    print(f" Resume: {_yn(args.resume)}")
    print(f" Fetch only: {_yn(args.fetch_only)}")

    asyncio.run(run_all_validations(
        run_medqa=run_medqa,
        run_mtsamples=run_mtsamples,
        run_pmc=run_pmc,
        max_cases=args.max_cases,
        seed=args.seed,
        include_drug_check=not args.no_drugs,
        include_guidelines=not args.no_guidelines,
        delay=args.delay,
        fetch_only=args.fetch_only,
        resume=args.resume,
    ))


if __name__ == "__main__":
    main()
|