bshepp
committed on
Commit
·
3d02eb2
1
Parent(s):
8aed835
Add incremental checkpoint saves, --resume flag, fix enum case-sensitivity, add HF_TOKEN to template
Browse files- src/backend/.env.template +4 -0
- src/backend/app/models/schemas.py +11 -1
- src/backend/validation/base.py +76 -16
- src/backend/validation/harness_medqa.py +24 -2
- src/backend/validation/harness_mtsamples.py +23 -2
- src/backend/validation/harness_pmc.py +23 -2
- src/backend/validation/run_validation.py +7 -0
src/backend/.env.template
CHANGED
|
@@ -17,6 +17,10 @@ MEDGEMMA_MODEL_ID=google/medgemma-27b-text-it
|
|
| 17 |
# OpenFDA (no key required for basic use, add for higher rate limits)
|
| 18 |
# OPENFDA_API_KEY=
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# --- RAG Configuration ---
|
| 21 |
CHROMA_PERSIST_DIR=./data/chroma
|
| 22 |
EMBEDDING_MODEL=all-MiniLM-L6-v2
|
|
|
|
| 17 |
# OpenFDA (no key required for basic use, add for higher rate limits)
|
| 18 |
# OPENFDA_API_KEY=
|
| 19 |
|
| 20 |
+
# HuggingFace token (for downloading datasets without rate limits)
|
| 21 |
+
# Get yours at: https://huggingface.co/settings/tokens
|
| 22 |
+
# HF_TOKEN=hf_your_token_here
|
| 23 |
+
|
| 24 |
# --- RAG Configuration ---
|
| 25 |
CHROMA_PERSIST_DIR=./data/chroma
|
| 26 |
EMBEDDING_MODEL=all-MiniLM-L6-v2
|
src/backend/app/models/schemas.py
CHANGED
|
@@ -9,7 +9,7 @@ from __future__ import annotations
|
|
| 9 |
from datetime import date, datetime
|
| 10 |
from enum import Enum
|
| 11 |
from typing import List, Optional
|
| 12 |
-
from pydantic import BaseModel, Field
|
| 13 |
|
| 14 |
|
| 15 |
# ──────────────────────────────────────────────
|
|
@@ -172,6 +172,16 @@ class ClinicalConflict(BaseModel):
|
|
| 172 |
"""A single detected conflict between guidelines and patient data."""
|
| 173 |
conflict_type: ConflictType = Field(..., description="Category of the conflict")
|
| 174 |
severity: Severity = Field(..., description="Potential clinical impact")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
guideline_source: str = Field(..., description="Which guideline flagged this")
|
| 176 |
guideline_text: str = Field(..., description="What the guideline recommends")
|
| 177 |
patient_data: str = Field(..., description="Relevant patient data that conflicts")
|
|
|
|
| 9 |
from datetime import date, datetime
|
| 10 |
from enum import Enum
|
| 11 |
from typing import List, Optional
|
| 12 |
+
from pydantic import BaseModel, Field, field_validator
|
| 13 |
|
| 14 |
|
| 15 |
# ──────────────────────────────────────────────
|
|
|
|
| 172 |
"""A single detected conflict between guidelines and patient data."""
|
| 173 |
conflict_type: ConflictType = Field(..., description="Category of the conflict")
|
| 174 |
severity: Severity = Field(..., description="Potential clinical impact")
|
| 175 |
+
|
| 176 |
+
@field_validator("conflict_type", mode="before")
|
| 177 |
+
@classmethod
|
| 178 |
+
def _normalise_conflict_type(cls, v: str) -> str:
|
| 179 |
+
return v.lower() if isinstance(v, str) else v
|
| 180 |
+
|
| 181 |
+
@field_validator("severity", mode="before")
|
| 182 |
+
@classmethod
|
| 183 |
+
def _normalise_severity(cls, v: str) -> str:
|
| 184 |
+
return v.lower() if isinstance(v, str) else v
|
| 185 |
guideline_source: str = Field(..., description="Which guideline flagged this")
|
| 186 |
guideline_text: str = Field(..., description="What the guideline recommends")
|
| 187 |
patient_data: str = Field(..., description="Relevant patient data that conflicts")
|
src/backend/validation/base.py
CHANGED
|
@@ -191,6 +191,7 @@ def diagnosis_in_differential(
|
|
| 191 |
# ──────────────────────────────────────────────
|
| 192 |
|
| 193 |
DATA_DIR = Path(__file__).resolve().parent / "data"
|
|
|
|
| 194 |
|
| 195 |
|
| 196 |
def ensure_data_dir():
|
|
@@ -198,16 +199,87 @@ def ensure_data_dir():
|
|
| 198 |
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 199 |
|
| 200 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
def save_results(summary: ValidationSummary, filename: str = None):
|
| 202 |
"""Save validation results to JSON."""
|
| 203 |
-
|
| 204 |
-
results_dir.mkdir(parents=True, exist_ok=True)
|
| 205 |
|
| 206 |
if filename is None:
|
| 207 |
ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
| 208 |
filename = f"{summary.dataset}_{ts}.json"
|
| 209 |
|
| 210 |
-
path =
|
| 211 |
|
| 212 |
# Convert to serializable dict
|
| 213 |
data = {
|
|
@@ -218,19 +290,7 @@ def save_results(summary: ValidationSummary, filename: str = None):
|
|
| 218 |
"metrics": summary.metrics,
|
| 219 |
"run_duration_sec": summary.run_duration_sec,
|
| 220 |
"timestamp": summary.timestamp,
|
| 221 |
-
"per_case": [
|
| 222 |
-
{
|
| 223 |
-
"case_id": r.case_id,
|
| 224 |
-
"success": r.success,
|
| 225 |
-
"scores": r.scores,
|
| 226 |
-
"pipeline_time_ms": r.pipeline_time_ms,
|
| 227 |
-
"step_results": r.step_results,
|
| 228 |
-
"report_summary": r.report_summary,
|
| 229 |
-
"error": r.error,
|
| 230 |
-
"details": r.details,
|
| 231 |
-
}
|
| 232 |
-
for r in summary.per_case
|
| 233 |
-
],
|
| 234 |
}
|
| 235 |
|
| 236 |
path.write_text(json.dumps(data, indent=2, default=str))
|
|
|
|
| 191 |
# ──────────────────────────────────────────────
|
| 192 |
|
| 193 |
DATA_DIR = Path(__file__).resolve().parent / "data"
|
| 194 |
+
RESULTS_DIR = Path(__file__).resolve().parent / "results"
|
| 195 |
|
| 196 |
|
| 197 |
def ensure_data_dir():
|
|
|
|
| 199 |
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 200 |
|
| 201 |
|
| 202 |
+
def _result_to_dict(r: ValidationResult) -> dict:
|
| 203 |
+
"""Convert a ValidationResult to a serialisable dict."""
|
| 204 |
+
return {
|
| 205 |
+
"case_id": r.case_id,
|
| 206 |
+
"source_dataset": r.source_dataset,
|
| 207 |
+
"success": r.success,
|
| 208 |
+
"scores": r.scores,
|
| 209 |
+
"pipeline_time_ms": r.pipeline_time_ms,
|
| 210 |
+
"step_results": r.step_results,
|
| 211 |
+
"report_summary": r.report_summary,
|
| 212 |
+
"error": r.error,
|
| 213 |
+
"details": r.details,
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# ──────────────────────────────────────────────
|
| 218 |
+
# Incremental checkpoint (JSONL)
|
| 219 |
+
# ──────────────────────────────────────────────
|
| 220 |
+
|
| 221 |
+
def checkpoint_path(dataset: str) -> Path:
|
| 222 |
+
"""Return the path to the checkpoint JSONL for *dataset*."""
|
| 223 |
+
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 224 |
+
return RESULTS_DIR / f"{dataset}_checkpoint.jsonl"
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def save_incremental(result: ValidationResult, dataset: str) -> None:
|
| 228 |
+
"""Append a single case result to the checkpoint JSONL file."""
|
| 229 |
+
path = checkpoint_path(dataset)
|
| 230 |
+
with open(path, "a", encoding="utf-8") as f:
|
| 231 |
+
f.write(json.dumps(_result_to_dict(result), default=str) + "\n")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def load_checkpoint(dataset: str) -> List[ValidationResult]:
|
| 235 |
+
"""
|
| 236 |
+
Load previously-completed results from the checkpoint file.
|
| 237 |
+
|
| 238 |
+
Returns a list of ValidationResult objects (may be empty).
|
| 239 |
+
"""
|
| 240 |
+
path = checkpoint_path(dataset)
|
| 241 |
+
if not path.exists():
|
| 242 |
+
return []
|
| 243 |
+
|
| 244 |
+
results: List[ValidationResult] = []
|
| 245 |
+
for line in path.read_text(encoding="utf-8").strip().split("\n"):
|
| 246 |
+
if not line.strip():
|
| 247 |
+
continue
|
| 248 |
+
d = json.loads(line)
|
| 249 |
+
results.append(ValidationResult(
|
| 250 |
+
case_id=d["case_id"],
|
| 251 |
+
source_dataset=d.get("source_dataset", dataset),
|
| 252 |
+
success=d["success"],
|
| 253 |
+
scores=d["scores"],
|
| 254 |
+
pipeline_time_ms=d.get("pipeline_time_ms", 0),
|
| 255 |
+
step_results=d.get("step_results", {}),
|
| 256 |
+
report_summary=d.get("report_summary"),
|
| 257 |
+
error=d.get("error"),
|
| 258 |
+
details=d.get("details", {}),
|
| 259 |
+
))
|
| 260 |
+
return results
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def clear_checkpoint(dataset: str) -> None:
|
| 264 |
+
"""Delete checkpoint file for a fresh run."""
|
| 265 |
+
path = checkpoint_path(dataset)
|
| 266 |
+
if path.exists():
|
| 267 |
+
path.unlink()
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# ──────────────────────────────────────────────
|
| 271 |
+
# Final results save
|
| 272 |
+
# ──────────────────────────────────────────────
|
| 273 |
+
|
| 274 |
def save_results(summary: ValidationSummary, filename: str = None):
|
| 275 |
"""Save validation results to JSON."""
|
| 276 |
+
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 277 |
|
| 278 |
if filename is None:
|
| 279 |
ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
| 280 |
filename = f"{summary.dataset}_{ts}.json"
|
| 281 |
|
| 282 |
+
path = RESULTS_DIR / filename
|
| 283 |
|
| 284 |
# Convert to serializable dict
|
| 285 |
data = {
|
|
|
|
| 290 |
"metrics": summary.metrics,
|
| 291 |
"run_duration_sec": summary.run_duration_sec,
|
| 292 |
"timestamp": summary.timestamp,
|
| 293 |
+
"per_case": [_result_to_dict(r) for r in summary.per_case],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
}
|
| 295 |
|
| 296 |
path.write_text(json.dumps(data, indent=2, default=str))
|
src/backend/validation/harness_medqa.py
CHANGED
|
@@ -29,12 +29,15 @@ from validation.base import (
|
|
| 29 |
ValidationCase,
|
| 30 |
ValidationResult,
|
| 31 |
ValidationSummary,
|
|
|
|
| 32 |
diagnosis_in_differential,
|
| 33 |
ensure_data_dir,
|
| 34 |
fuzzy_match,
|
|
|
|
| 35 |
normalize_text,
|
| 36 |
print_summary,
|
| 37 |
run_cds_pipeline,
|
|
|
|
| 38 |
save_results,
|
| 39 |
)
|
| 40 |
|
|
@@ -186,6 +189,7 @@ async def validate_medqa(
|
|
| 186 |
include_drug_check: bool = False,
|
| 187 |
include_guidelines: bool = True,
|
| 188 |
delay_between_cases: float = 2.0,
|
|
|
|
| 189 |
) -> ValidationSummary:
|
| 190 |
"""
|
| 191 |
Run MedQA cases through the CDS pipeline and score results.
|
|
@@ -195,11 +199,27 @@ async def validate_medqa(
|
|
| 195 |
include_drug_check: Whether to run drug interaction check (slower)
|
| 196 |
include_guidelines: Whether to include guideline retrieval
|
| 197 |
delay_between_cases: Seconds to wait between cases (rate limiting)
|
|
|
|
| 198 |
"""
|
| 199 |
results: List[ValidationResult] = []
|
| 200 |
start_time = time.time()
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
for i, case in enumerate(cases):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
print(f"\n [{i+1}/{len(cases)}] {case.case_id}: ", end="", flush=True)
|
| 204 |
|
| 205 |
case_start = time.monotonic()
|
|
@@ -257,7 +277,7 @@ async def validate_medqa(
|
|
| 257 |
details = {"correct_answer": correct_answer, "error": error}
|
| 258 |
print(f"✗ FAILED: {error[:80] if error else 'unknown'}")
|
| 259 |
|
| 260 |
-
|
| 261 |
case_id=case.case_id,
|
| 262 |
source_dataset="medqa",
|
| 263 |
success=report is not None,
|
|
@@ -267,7 +287,9 @@ async def validate_medqa(
|
|
| 267 |
report_summary=report.patient_summary[:200] if report else None,
|
| 268 |
error=error,
|
| 269 |
details=details,
|
| 270 |
-
)
|
|
|
|
|
|
|
| 271 |
|
| 272 |
# Rate limit
|
| 273 |
if i < len(cases) - 1:
|
|
|
|
| 29 |
ValidationCase,
|
| 30 |
ValidationResult,
|
| 31 |
ValidationSummary,
|
| 32 |
+
clear_checkpoint,
|
| 33 |
diagnosis_in_differential,
|
| 34 |
ensure_data_dir,
|
| 35 |
fuzzy_match,
|
| 36 |
+
load_checkpoint,
|
| 37 |
normalize_text,
|
| 38 |
print_summary,
|
| 39 |
run_cds_pipeline,
|
| 40 |
+
save_incremental,
|
| 41 |
save_results,
|
| 42 |
)
|
| 43 |
|
|
|
|
| 189 |
include_drug_check: bool = False,
|
| 190 |
include_guidelines: bool = True,
|
| 191 |
delay_between_cases: float = 2.0,
|
| 192 |
+
resume: bool = False,
|
| 193 |
) -> ValidationSummary:
|
| 194 |
"""
|
| 195 |
Run MedQA cases through the CDS pipeline and score results.
|
|
|
|
| 199 |
include_drug_check: Whether to run drug interaction check (slower)
|
| 200 |
include_guidelines: Whether to include guideline retrieval
|
| 201 |
delay_between_cases: Seconds to wait between cases (rate limiting)
|
| 202 |
+
resume: If True, skip cases already in checkpoint and continue
|
| 203 |
"""
|
| 204 |
results: List[ValidationResult] = []
|
| 205 |
start_time = time.time()
|
| 206 |
|
| 207 |
+
# Resume support: load completed cases from checkpoint
|
| 208 |
+
completed_ids: set = set()
|
| 209 |
+
if resume:
|
| 210 |
+
prior = load_checkpoint("medqa")
|
| 211 |
+
if prior:
|
| 212 |
+
results.extend(prior)
|
| 213 |
+
completed_ids = {r.case_id for r in prior}
|
| 214 |
+
print(f" Resuming: {len(prior)} cases loaded from checkpoint, {len(cases) - len(completed_ids)} remaining")
|
| 215 |
+
else:
|
| 216 |
+
clear_checkpoint("medqa")
|
| 217 |
+
|
| 218 |
for i, case in enumerate(cases):
|
| 219 |
+
if case.case_id in completed_ids:
|
| 220 |
+
print(f"\n [{i+1}/{len(cases)}] {case.case_id}: (cached) skipped")
|
| 221 |
+
continue
|
| 222 |
+
|
| 223 |
print(f"\n [{i+1}/{len(cases)}] {case.case_id}: ", end="", flush=True)
|
| 224 |
|
| 225 |
case_start = time.monotonic()
|
|
|
|
| 277 |
details = {"correct_answer": correct_answer, "error": error}
|
| 278 |
print(f"✗ FAILED: {error[:80] if error else 'unknown'}")
|
| 279 |
|
| 280 |
+
result = ValidationResult(
|
| 281 |
case_id=case.case_id,
|
| 282 |
source_dataset="medqa",
|
| 283 |
success=report is not None,
|
|
|
|
| 287 |
report_summary=report.patient_summary[:200] if report else None,
|
| 288 |
error=error,
|
| 289 |
details=details,
|
| 290 |
+
)
|
| 291 |
+
results.append(result)
|
| 292 |
+
save_incremental(result, "medqa") # checkpoint after every case
|
| 293 |
|
| 294 |
# Rate limit
|
| 295 |
if i < len(cases) - 1:
|
src/backend/validation/harness_mtsamples.py
CHANGED
|
@@ -33,11 +33,14 @@ from validation.base import (
|
|
| 33 |
ValidationCase,
|
| 34 |
ValidationResult,
|
| 35 |
ValidationSummary,
|
|
|
|
| 36 |
ensure_data_dir,
|
| 37 |
fuzzy_match,
|
|
|
|
| 38 |
normalize_text,
|
| 39 |
print_summary,
|
| 40 |
run_cds_pipeline,
|
|
|
|
| 41 |
save_results,
|
| 42 |
)
|
| 43 |
|
|
@@ -238,6 +241,7 @@ async def validate_mtsamples(
|
|
| 238 |
include_drug_check: bool = True,
|
| 239 |
include_guidelines: bool = True,
|
| 240 |
delay_between_cases: float = 2.0,
|
|
|
|
| 241 |
) -> ValidationSummary:
|
| 242 |
"""
|
| 243 |
Run MTSamples cases through the CDS pipeline and score results.
|
|
@@ -245,8 +249,23 @@ async def validate_mtsamples(
|
|
| 245 |
results: List[ValidationResult] = []
|
| 246 |
start_time = time.time()
|
| 247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
for i, case in enumerate(cases):
|
| 249 |
specialty = case.ground_truth.get("specialty", "?")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty}): ", end="", flush=True)
|
| 251 |
|
| 252 |
case_start = time.monotonic()
|
|
@@ -321,7 +340,7 @@ async def validate_mtsamples(
|
|
| 321 |
details = {"specialty": specialty, "error": error}
|
| 322 |
print(f"✗ FAILED: {error[:80] if error else 'unknown'}")
|
| 323 |
|
| 324 |
-
|
| 325 |
case_id=case.case_id,
|
| 326 |
source_dataset="mtsamples",
|
| 327 |
success=report is not None,
|
|
@@ -331,7 +350,9 @@ async def validate_mtsamples(
|
|
| 331 |
report_summary=report.patient_summary[:200] if report else None,
|
| 332 |
error=error,
|
| 333 |
details=details,
|
| 334 |
-
)
|
|
|
|
|
|
|
| 335 |
|
| 336 |
if i < len(cases) - 1:
|
| 337 |
await asyncio.sleep(delay_between_cases)
|
|
|
|
| 33 |
ValidationCase,
|
| 34 |
ValidationResult,
|
| 35 |
ValidationSummary,
|
| 36 |
+
clear_checkpoint,
|
| 37 |
ensure_data_dir,
|
| 38 |
fuzzy_match,
|
| 39 |
+
load_checkpoint,
|
| 40 |
normalize_text,
|
| 41 |
print_summary,
|
| 42 |
run_cds_pipeline,
|
| 43 |
+
save_incremental,
|
| 44 |
save_results,
|
| 45 |
)
|
| 46 |
|
|
|
|
| 241 |
include_drug_check: bool = True,
|
| 242 |
include_guidelines: bool = True,
|
| 243 |
delay_between_cases: float = 2.0,
|
| 244 |
+
resume: bool = False,
|
| 245 |
) -> ValidationSummary:
|
| 246 |
"""
|
| 247 |
Run MTSamples cases through the CDS pipeline and score results.
|
|
|
|
| 249 |
results: List[ValidationResult] = []
|
| 250 |
start_time = time.time()
|
| 251 |
|
| 252 |
+
# Resume support
|
| 253 |
+
completed_ids: set = set()
|
| 254 |
+
if resume:
|
| 255 |
+
prior = load_checkpoint("mtsamples")
|
| 256 |
+
if prior:
|
| 257 |
+
results.extend(prior)
|
| 258 |
+
completed_ids = {r.case_id for r in prior}
|
| 259 |
+
print(f" Resuming: {len(prior)} cases loaded from checkpoint, {len(cases) - len(completed_ids)} remaining")
|
| 260 |
+
else:
|
| 261 |
+
clear_checkpoint("mtsamples")
|
| 262 |
+
|
| 263 |
for i, case in enumerate(cases):
|
| 264 |
specialty = case.ground_truth.get("specialty", "?")
|
| 265 |
+
if case.case_id in completed_ids:
|
| 266 |
+
print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty}): (cached) skipped")
|
| 267 |
+
continue
|
| 268 |
+
|
| 269 |
print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty}): ", end="", flush=True)
|
| 270 |
|
| 271 |
case_start = time.monotonic()
|
|
|
|
| 340 |
details = {"specialty": specialty, "error": error}
|
| 341 |
print(f"✗ FAILED: {error[:80] if error else 'unknown'}")
|
| 342 |
|
| 343 |
+
result = ValidationResult(
|
| 344 |
case_id=case.case_id,
|
| 345 |
source_dataset="mtsamples",
|
| 346 |
success=report is not None,
|
|
|
|
| 350 |
report_summary=report.patient_summary[:200] if report else None,
|
| 351 |
error=error,
|
| 352 |
details=details,
|
| 353 |
+
)
|
| 354 |
+
results.append(result)
|
| 355 |
+
save_incremental(result, "mtsamples") # checkpoint after every case
|
| 356 |
|
| 357 |
if i < len(cases) - 1:
|
| 358 |
await asyncio.sleep(delay_between_cases)
|
src/backend/validation/harness_pmc.py
CHANGED
|
@@ -31,12 +31,15 @@ from validation.base import (
|
|
| 31 |
ValidationCase,
|
| 32 |
ValidationResult,
|
| 33 |
ValidationSummary,
|
|
|
|
| 34 |
diagnosis_in_differential,
|
| 35 |
ensure_data_dir,
|
| 36 |
fuzzy_match,
|
|
|
|
| 37 |
normalize_text,
|
| 38 |
print_summary,
|
| 39 |
run_cds_pipeline,
|
|
|
|
| 40 |
save_results,
|
| 41 |
)
|
| 42 |
|
|
@@ -322,6 +325,7 @@ async def validate_pmc(
|
|
| 322 |
include_drug_check: bool = True,
|
| 323 |
include_guidelines: bool = True,
|
| 324 |
delay_between_cases: float = 2.0,
|
|
|
|
| 325 |
) -> ValidationSummary:
|
| 326 |
"""
|
| 327 |
Run PMC case reports through the CDS pipeline and score results.
|
|
@@ -329,9 +333,24 @@ async def validate_pmc(
|
|
| 329 |
results: List[ValidationResult] = []
|
| 330 |
start_time = time.time()
|
| 331 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
for i, case in enumerate(cases):
|
| 333 |
dx = case.ground_truth.get("diagnosis", "?")
|
| 334 |
specialty = case.ground_truth.get("specialty", "?")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty} — {dx[:40]}): ", end="", flush=True)
|
| 336 |
|
| 337 |
case_start = time.monotonic()
|
|
@@ -393,7 +412,7 @@ async def validate_pmc(
|
|
| 393 |
details = {"target_diagnosis": target_diagnosis, "error": error}
|
| 394 |
print(f"✗ FAILED: {error[:80] if error else 'unknown'}")
|
| 395 |
|
| 396 |
-
|
| 397 |
case_id=case.case_id,
|
| 398 |
source_dataset="pmc",
|
| 399 |
success=report is not None,
|
|
@@ -403,7 +422,9 @@ async def validate_pmc(
|
|
| 403 |
report_summary=report.patient_summary[:200] if report else None,
|
| 404 |
error=error,
|
| 405 |
details=details,
|
| 406 |
-
)
|
|
|
|
|
|
|
| 407 |
|
| 408 |
if i < len(cases) - 1:
|
| 409 |
await asyncio.sleep(delay_between_cases)
|
|
|
|
| 31 |
ValidationCase,
|
| 32 |
ValidationResult,
|
| 33 |
ValidationSummary,
|
| 34 |
+
clear_checkpoint,
|
| 35 |
diagnosis_in_differential,
|
| 36 |
ensure_data_dir,
|
| 37 |
fuzzy_match,
|
| 38 |
+
load_checkpoint,
|
| 39 |
normalize_text,
|
| 40 |
print_summary,
|
| 41 |
run_cds_pipeline,
|
| 42 |
+
save_incremental,
|
| 43 |
save_results,
|
| 44 |
)
|
| 45 |
|
|
|
|
| 325 |
include_drug_check: bool = True,
|
| 326 |
include_guidelines: bool = True,
|
| 327 |
delay_between_cases: float = 2.0,
|
| 328 |
+
resume: bool = False,
|
| 329 |
) -> ValidationSummary:
|
| 330 |
"""
|
| 331 |
Run PMC case reports through the CDS pipeline and score results.
|
|
|
|
| 333 |
results: List[ValidationResult] = []
|
| 334 |
start_time = time.time()
|
| 335 |
|
| 336 |
+
# Resume support
|
| 337 |
+
completed_ids: set = set()
|
| 338 |
+
if resume:
|
| 339 |
+
prior = load_checkpoint("pmc")
|
| 340 |
+
if prior:
|
| 341 |
+
results.extend(prior)
|
| 342 |
+
completed_ids = {r.case_id for r in prior}
|
| 343 |
+
print(f" Resuming: {len(prior)} cases loaded from checkpoint, {len(cases) - len(completed_ids)} remaining")
|
| 344 |
+
else:
|
| 345 |
+
clear_checkpoint("pmc")
|
| 346 |
+
|
| 347 |
for i, case in enumerate(cases):
|
| 348 |
dx = case.ground_truth.get("diagnosis", "?")
|
| 349 |
specialty = case.ground_truth.get("specialty", "?")
|
| 350 |
+
if case.case_id in completed_ids:
|
| 351 |
+
print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty}): (cached) skipped")
|
| 352 |
+
continue
|
| 353 |
+
|
| 354 |
print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty} — {dx[:40]}): ", end="", flush=True)
|
| 355 |
|
| 356 |
case_start = time.monotonic()
|
|
|
|
| 412 |
details = {"target_diagnosis": target_diagnosis, "error": error}
|
| 413 |
print(f"✗ FAILED: {error[:80] if error else 'unknown'}")
|
| 414 |
|
| 415 |
+
result = ValidationResult(
|
| 416 |
case_id=case.case_id,
|
| 417 |
source_dataset="pmc",
|
| 418 |
success=report is not None,
|
|
|
|
| 422 |
report_summary=report.patient_summary[:200] if report else None,
|
| 423 |
error=error,
|
| 424 |
details=details,
|
| 425 |
+
)
|
| 426 |
+
results.append(result)
|
| 427 |
+
save_incremental(result, "pmc") # checkpoint after every case
|
| 428 |
|
| 429 |
if i < len(cases) - 1:
|
| 430 |
await asyncio.sleep(delay_between_cases)
|
src/backend/validation/run_validation.py
CHANGED
|
@@ -48,6 +48,7 @@ async def run_all_validations(
|
|
| 48 |
include_guidelines: bool = True,
|
| 49 |
delay: float = 2.0,
|
| 50 |
fetch_only: bool = False,
|
|
|
|
| 51 |
) -> dict:
|
| 52 |
"""
|
| 53 |
Run validation against selected datasets.
|
|
@@ -73,6 +74,7 @@ async def run_all_validations(
|
|
| 73 |
include_drug_check=include_drug_check,
|
| 74 |
include_guidelines=include_guidelines,
|
| 75 |
delay_between_cases=delay,
|
|
|
|
| 76 |
)
|
| 77 |
print_summary(summary)
|
| 78 |
save_results(summary)
|
|
@@ -94,6 +96,7 @@ async def run_all_validations(
|
|
| 94 |
include_drug_check=include_drug_check,
|
| 95 |
include_guidelines=include_guidelines,
|
| 96 |
delay_between_cases=delay,
|
|
|
|
| 97 |
)
|
| 98 |
print_summary(summary)
|
| 99 |
save_results(summary)
|
|
@@ -115,6 +118,7 @@ async def run_all_validations(
|
|
| 115 |
include_drug_check=include_drug_check,
|
| 116 |
include_guidelines=include_guidelines,
|
| 117 |
delay_between_cases=delay,
|
|
|
|
| 118 |
)
|
| 119 |
print_summary(summary)
|
| 120 |
save_results(summary)
|
|
@@ -235,6 +239,7 @@ Examples:
|
|
| 235 |
config_group.add_argument("--delay", type=float, default=2.0, help="Delay between cases in seconds (default: 2.0)")
|
| 236 |
config_group.add_argument("--no-drugs", action="store_true", help="Skip drug interaction checks")
|
| 237 |
config_group.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
|
|
|
|
| 238 |
config_group.add_argument("--fetch-only", action="store_true", help="Only download data, don't run pipeline")
|
| 239 |
|
| 240 |
args = parser.parse_args()
|
|
@@ -254,6 +259,7 @@ Examples:
|
|
| 254 |
print(f" Cases/dataset: {args.max_cases}")
|
| 255 |
print(f" Drug check: {'Yes' if not args.no_drugs else 'No'}")
|
| 256 |
print(f" Guidelines: {'Yes' if not args.no_guidelines else 'No'}")
|
|
|
|
| 257 |
print(f" Fetch only: {'Yes' if args.fetch_only else 'No'}")
|
| 258 |
|
| 259 |
asyncio.run(run_all_validations(
|
|
@@ -266,6 +272,7 @@ Examples:
|
|
| 266 |
include_guidelines=not args.no_guidelines,
|
| 267 |
delay=args.delay,
|
| 268 |
fetch_only=args.fetch_only,
|
|
|
|
| 269 |
))
|
| 270 |
|
| 271 |
|
|
|
|
| 48 |
include_guidelines: bool = True,
|
| 49 |
delay: float = 2.0,
|
| 50 |
fetch_only: bool = False,
|
| 51 |
+
resume: bool = False,
|
| 52 |
) -> dict:
|
| 53 |
"""
|
| 54 |
Run validation against selected datasets.
|
|
|
|
| 74 |
include_drug_check=include_drug_check,
|
| 75 |
include_guidelines=include_guidelines,
|
| 76 |
delay_between_cases=delay,
|
| 77 |
+
resume=resume,
|
| 78 |
)
|
| 79 |
print_summary(summary)
|
| 80 |
save_results(summary)
|
|
|
|
| 96 |
include_drug_check=include_drug_check,
|
| 97 |
include_guidelines=include_guidelines,
|
| 98 |
delay_between_cases=delay,
|
| 99 |
+
resume=resume,
|
| 100 |
)
|
| 101 |
print_summary(summary)
|
| 102 |
save_results(summary)
|
|
|
|
| 118 |
include_drug_check=include_drug_check,
|
| 119 |
include_guidelines=include_guidelines,
|
| 120 |
delay_between_cases=delay,
|
| 121 |
+
resume=resume,
|
| 122 |
)
|
| 123 |
print_summary(summary)
|
| 124 |
save_results(summary)
|
|
|
|
| 239 |
config_group.add_argument("--delay", type=float, default=2.0, help="Delay between cases in seconds (default: 2.0)")
|
| 240 |
config_group.add_argument("--no-drugs", action="store_true", help="Skip drug interaction checks")
|
| 241 |
config_group.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
|
| 242 |
+
config_group.add_argument("--resume", action="store_true", help="Resume from checkpoint (skip already-completed cases)")
|
| 243 |
config_group.add_argument("--fetch-only", action="store_true", help="Only download data, don't run pipeline")
|
| 244 |
|
| 245 |
args = parser.parse_args()
|
|
|
|
| 259 |
print(f" Cases/dataset: {args.max_cases}")
|
| 260 |
print(f" Drug check: {'Yes' if not args.no_drugs else 'No'}")
|
| 261 |
print(f" Guidelines: {'Yes' if not args.no_guidelines else 'No'}")
|
| 262 |
+
print(f" Resume: {'Yes' if args.resume else 'No'}")
|
| 263 |
print(f" Fetch only: {'Yes' if args.fetch_only else 'No'}")
|
| 264 |
|
| 265 |
asyncio.run(run_all_validations(
|
|
|
|
| 272 |
include_guidelines=not args.no_guidelines,
|
| 273 |
delay=args.delay,
|
| 274 |
fetch_only=args.fetch_only,
|
| 275 |
+
resume=args.resume,
|
| 276 |
))
|
| 277 |
|
| 278 |
|