|
|
""" |
|
|
PMC Case Reports fetcher and validation harness. |
|
|
|
|
|
Fetches published clinical case reports from PubMed Central and evaluates |
|
|
the CDS pipeline's diagnostic accuracy against gold-standard diagnoses. |
|
|
|
|
|
Source: NCBI PubMed / PubMed Central (E-utilities API) |
|
|
Format: XML abstracts with case presentations and final diagnoses |
|
|
|
|
|
Metrics: |
|
|
- diagnostic_accuracy: Correct diagnosis appears in differential |
|
|
- top3_accuracy: Correct diagnosis in top 3 |
|
|
- parse_success_rate: Pipeline completed without crashing |
|
|
- has_recommendations: Report includes actionable next steps |
|
|
""" |
|
|
from __future__ import annotations |
|
|
|
|
|
import asyncio |
|
|
import json |
|
|
import random |
|
|
import re |
|
|
import time |
|
|
import xml.etree.ElementTree as ET |
|
|
from pathlib import Path |
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
import httpx |
|
|
|
|
|
from validation.base import ( |
|
|
DATA_DIR, |
|
|
ValidationCase, |
|
|
ValidationResult, |
|
|
ValidationSummary, |
|
|
clear_checkpoint, |
|
|
diagnosis_in_differential, |
|
|
ensure_data_dir, |
|
|
fuzzy_match, |
|
|
load_checkpoint, |
|
|
normalize_text, |
|
|
print_summary, |
|
|
run_cds_pipeline, |
|
|
save_incremental, |
|
|
save_results, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
EUTILS_BASE = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils" |
|
|
ESEARCH_URL = f"{EUTILS_BASE}/esearch.fcgi" |
|
|
EFETCH_URL = f"{EUTILS_BASE}/efetch.fcgi" |
|
|
|
|
|
|
|
|
|
|
|
CASE_REPORT_QUERIES = [ |
|
|
('"case report"[Title] AND "myocardial infarction"[Title] AND diagnosis', "Cardiology"), |
|
|
('"case report"[Title] AND "pneumonia"[Title] AND diagnosis', "Pulmonology"), |
|
|
('"case report"[Title] AND "diabetic ketoacidosis"[Title]', "Endocrinology"), |
|
|
('"case report"[Title] AND "stroke"[Title] AND diagnosis', "Neurology"), |
|
|
('"case report"[Title] AND "appendicitis"[Title] AND diagnosis', "Surgery"), |
|
|
('"case report"[Title] AND "pulmonary embolism"[Title]', "Pulmonology"), |
|
|
('"case report"[Title] AND "sepsis"[Title] AND management', "Critical Care"), |
|
|
('"case report"[Title] AND "heart failure"[Title] AND management', "Cardiology"), |
|
|
('"case report"[Title] AND "pancreatitis"[Title] AND diagnosis', "Gastroenterology"), |
|
|
('"case report"[Title] AND "meningitis"[Title] AND diagnosis', "Neurology/ID"), |
|
|
('"case report"[Title] AND "urinary tract infection"[Title]', "Urology/ID"), |
|
|
('"case report"[Title] AND "thyroid"[Title] AND "nodule"', "Endocrinology"), |
|
|
('"case report"[Title] AND "deep vein thrombosis"[Title]', "Hematology"), |
|
|
('"case report"[Title] AND "anaphylaxis"[Title]', "Allergy/EM"), |
|
|
('"case report"[Title] AND "renal failure"[Title] AND acute', "Nephrology"), |
|
|
('"case report"[Title] AND "liver cirrhosis"[Title]', "Hepatology"), |
|
|
('"case report"[Title] AND "asthma"[Title] AND exacerbation', "Pulmonology"), |
|
|
('"case report"[Title] AND "seizure"[Title] AND diagnosis', "Neurology"), |
|
|
('"case report"[Title] AND "hypoglycemia"[Title]', "Endocrinology"), |
|
|
('"case report"[Title] AND "gastrointestinal bleeding"[Title]', "Gastroenterology"), |
|
|
] |
|
|
|
|
|
|
|
|
async def fetch_pmc_cases( |
|
|
max_cases: int = 20, |
|
|
seed: int = 42, |
|
|
) -> List[ValidationCase]: |
|
|
""" |
|
|
Fetch case reports from PubMed and convert to ValidationCase objects. |
|
|
|
|
|
Uses PubMed E-utilities to search for case reports with clear diagnoses, |
|
|
then extracts the clinical presentation and diagnosis from abstracts. |
|
|
|
|
|
Args: |
|
|
max_cases: Maximum number of cases to fetch |
|
|
seed: Random seed for reproducible selection |
|
|
""" |
|
|
ensure_data_dir() |
|
|
cache_path = DATA_DIR / "pmc_cases.json" |
|
|
|
|
|
if cache_path.exists(): |
|
|
print(f" Loading PMC cases from cache: {cache_path}") |
|
|
cached = json.loads(cache_path.read_text(encoding="utf-8")) |
|
|
cases = [ValidationCase(**c) for c in cached] |
|
|
if len(cases) >= max_cases: |
|
|
random.seed(seed) |
|
|
return random.sample(cases, min(max_cases, len(cases))) |
|
|
|
|
|
|
|
|
print(f" Fetching case reports from PubMed...") |
|
|
cases = await _fetch_from_pubmed(max_cases, seed) |
|
|
|
|
|
if cases: |
|
|
|
|
|
cached_data = [ |
|
|
{ |
|
|
"case_id": c.case_id, |
|
|
"source_dataset": c.source_dataset, |
|
|
"input_text": c.input_text, |
|
|
"ground_truth": c.ground_truth, |
|
|
"metadata": c.metadata, |
|
|
} |
|
|
for c in cases |
|
|
] |
|
|
cache_path.write_text(json.dumps(cached_data, indent=2), encoding="utf-8") |
|
|
print(f" Cached {len(cases)} PMC cases to {cache_path}") |
|
|
|
|
|
print(f" Loaded {len(cases)} PMC case reports") |
|
|
return cases |
|
|
|
|
|
|
|
|
async def _fetch_from_pubmed(max_cases: int, seed: int) -> List[ValidationCase]: |
|
|
"""Fetch case reports via PubMed E-utilities.""" |
|
|
cases = [] |
|
|
random.seed(seed) |
|
|
queries = random.sample(CASE_REPORT_QUERIES, min(max_cases, len(CASE_REPORT_QUERIES))) |
|
|
|
|
|
async with httpx.AsyncClient(timeout=30, follow_redirects=True) as client: |
|
|
for query_text, specialty in queries: |
|
|
if len(cases) >= max_cases: |
|
|
break |
|
|
|
|
|
try: |
|
|
|
|
|
pmids = await _esearch(client, query_text, retmax=3) |
|
|
if not pmids: |
|
|
continue |
|
|
|
|
|
|
|
|
for pmid in pmids[:1]: |
|
|
abstract_data = await _efetch_abstract(client, pmid) |
|
|
if not abstract_data: |
|
|
continue |
|
|
|
|
|
title, abstract = abstract_data |
|
|
|
|
|
|
|
|
presentation, diagnosis = _extract_case_and_diagnosis(title, abstract, query_text) |
|
|
if not presentation or not diagnosis: |
|
|
continue |
|
|
|
|
|
cases.append(ValidationCase( |
|
|
case_id=f"pmc_{pmid}", |
|
|
source_dataset="pmc", |
|
|
input_text=presentation, |
|
|
ground_truth={ |
|
|
"diagnosis": diagnosis, |
|
|
"specialty": specialty, |
|
|
"title": title, |
|
|
}, |
|
|
metadata={ |
|
|
"pmid": pmid, |
|
|
"full_abstract": abstract, |
|
|
}, |
|
|
)) |
|
|
|
|
|
if len(cases) >= max_cases: |
|
|
break |
|
|
|
|
|
|
|
|
await asyncio.sleep(0.5) |
|
|
|
|
|
except Exception as e: |
|
|
print(f" Warning: Query failed '{query_text[:40]}...': {e}") |
|
|
continue |
|
|
|
|
|
return cases |
|
|
|
|
|
|
|
|
async def _esearch(client: httpx.AsyncClient, query: str, retmax: int = 3) -> List[str]: |
|
|
"""Search PubMed and return PMIDs.""" |
|
|
params = { |
|
|
"db": "pubmed", |
|
|
"term": query, |
|
|
"retmax": retmax, |
|
|
"retmode": "json", |
|
|
"sort": "relevance", |
|
|
} |
|
|
r = await client.get(ESEARCH_URL, params=params) |
|
|
r.raise_for_status() |
|
|
data = r.json() |
|
|
return data.get("esearchresult", {}).get("idlist", []) |
|
|
|
|
|
|
|
|
async def _efetch_abstract(client: httpx.AsyncClient, pmid: str) -> Optional[Tuple[str, str]]: |
|
|
"""Fetch the title and abstract for a PMID.""" |
|
|
params = { |
|
|
"db": "pubmed", |
|
|
"id": pmid, |
|
|
"retmode": "xml", |
|
|
} |
|
|
r = await client.get(EFETCH_URL, params=params) |
|
|
r.raise_for_status() |
|
|
|
|
|
try: |
|
|
root = ET.fromstring(r.text) |
|
|
|
|
|
|
|
|
title_el = root.find(".//ArticleTitle") |
|
|
title = title_el.text if title_el is not None and title_el.text else "" |
|
|
|
|
|
|
|
|
abstract_parts = [] |
|
|
for abs_text in root.findall(".//AbstractText"): |
|
|
label = abs_text.get("Label", "") |
|
|
text = abs_text.text or "" |
|
|
|
|
|
full_text = (abs_text.text or "") + "".join( |
|
|
(child.text or "") + (child.tail or "") for child in abs_text |
|
|
) |
|
|
if label: |
|
|
abstract_parts.append(f"{label}: {full_text.strip()}") |
|
|
else: |
|
|
abstract_parts.append(full_text.strip()) |
|
|
|
|
|
abstract = " ".join(abstract_parts) |
|
|
|
|
|
if len(abstract) < 100: |
|
|
return None |
|
|
|
|
|
return title, abstract |
|
|
|
|
|
except ET.ParseError: |
|
|
return None |
|
|
|
|
|
|
|
|
def _extract_case_and_diagnosis( |
|
|
title: str, abstract: str, search_query: str |
|
|
) -> Tuple[Optional[str], Optional[str]]: |
|
|
""" |
|
|
Extract the clinical presentation and final diagnosis from a case report abstract. |
|
|
|
|
|
Strategy: |
|
|
1. Try structured abstract sections (CASE PRESENTATION, DIAGNOSIS, etc.) |
|
|
2. Extract diagnosis from the title (common pattern: "A case of [diagnosis]") |
|
|
3. Fall back to using the search condition as the expected diagnosis |
|
|
""" |
|
|
|
|
|
diagnosis = None |
|
|
title_patterns = [ |
|
|
r"case (?:report )?of (.+?)(?:\.|:|$)", |
|
|
r"presenting (?:as|with) (.+?)(?:\.|:|$)", |
|
|
r"diagnosed (?:as|with) (.+?)(?:\.|:|$)", |
|
|
r"rare case of (.+?)(?:\.|:|$)", |
|
|
r"unusual (?:case|presentation) of (.+?)(?:\.|:|$)", |
|
|
|
|
|
r"^(.+?):\s*[Aa]\s*[Cc]ase\s*[Rr]eport", |
|
|
|
|
|
r"^(.+?)\s*[-ββ]\s*[Cc]ase\s*[Rr]eport", |
|
|
|
|
|
r"[Cc]ase\s+of\s+(.+?)(?:\.|:|,|$)", |
|
|
] |
|
|
for pattern in title_patterns: |
|
|
match = re.search(pattern, title, re.IGNORECASE) |
|
|
if match: |
|
|
diagnosis = match.group(1).strip() |
|
|
break |
|
|
|
|
|
if not diagnosis: |
|
|
|
|
|
|
|
|
|
|
|
matches = re.findall(r'"([^"]+)"', search_query) |
|
|
for m in matches: |
|
|
if m.lower() != "case report": |
|
|
diagnosis = m |
|
|
break |
|
|
|
|
|
if not diagnosis: |
|
|
return None, None |
|
|
|
|
|
|
|
|
diagnosis = diagnosis.strip().rstrip('.') |
|
|
|
|
|
|
|
|
|
|
|
presentation_sections = ["CASE PRESENTATION", "CASE REPORT", "CASE", "CLINICAL PRESENTATION", "HISTORY"] |
|
|
conclusion_sections = ["CONCLUSION", "DISCUSSION", "OUTCOME", "DIAGNOSIS", "RESULTS"] |
|
|
|
|
|
|
|
|
presentation = abstract |
|
|
|
|
|
|
|
|
for cs in conclusion_sections: |
|
|
pattern = re.compile(rf'\b{cs}\b[:\s]', re.IGNORECASE) |
|
|
match = pattern.search(abstract) |
|
|
if match: |
|
|
|
|
|
candidate = abstract[:match.start()].strip() |
|
|
if len(candidate) > 100: |
|
|
presentation = candidate |
|
|
break |
|
|
|
|
|
|
|
|
presentation = presentation.strip() |
|
|
if len(presentation) < 50: |
|
|
presentation = abstract |
|
|
|
|
|
return presentation, diagnosis |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def validate_pmc( |
|
|
cases: List[ValidationCase], |
|
|
include_drug_check: bool = True, |
|
|
include_guidelines: bool = True, |
|
|
delay_between_cases: float = 2.0, |
|
|
resume: bool = False, |
|
|
) -> ValidationSummary: |
|
|
""" |
|
|
Run PMC case reports through the CDS pipeline and score results. |
|
|
""" |
|
|
results: List[ValidationResult] = [] |
|
|
start_time = time.time() |
|
|
|
|
|
|
|
|
completed_ids: set = set() |
|
|
if resume: |
|
|
prior = load_checkpoint("pmc") |
|
|
if prior: |
|
|
results.extend(prior) |
|
|
completed_ids = {r.case_id for r in prior} |
|
|
print(f" Resuming: {len(prior)} cases loaded from checkpoint, {len(cases) - len(completed_ids)} remaining") |
|
|
else: |
|
|
clear_checkpoint("pmc") |
|
|
|
|
|
for i, case in enumerate(cases): |
|
|
dx = case.ground_truth.get("diagnosis", "?") |
|
|
specialty = case.ground_truth.get("specialty", "?") |
|
|
if case.case_id in completed_ids: |
|
|
print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty}): (cached) skipped") |
|
|
continue |
|
|
|
|
|
print(f"\n [{i+1}/{len(cases)}] {case.case_id} ({specialty} β {dx[:40]}): ", end="", flush=True) |
|
|
|
|
|
case_start = time.monotonic() |
|
|
|
|
|
state, report, error = await run_cds_pipeline( |
|
|
patient_text=case.input_text, |
|
|
include_drug_check=include_drug_check, |
|
|
include_guidelines=include_guidelines, |
|
|
) |
|
|
|
|
|
elapsed_ms = int((time.monotonic() - case_start) * 1000) |
|
|
|
|
|
step_results = {} |
|
|
if state: |
|
|
step_results = {s.step_id: s.status.value for s in state.steps} |
|
|
|
|
|
scores = {} |
|
|
details = {} |
|
|
target_diagnosis = case.ground_truth["diagnosis"] |
|
|
|
|
|
if report: |
|
|
|
|
|
found_any, rank_any, loc_any = diagnosis_in_differential(target_diagnosis, report) |
|
|
scores["diagnostic_accuracy"] = 1.0 if found_any else 0.0 |
|
|
|
|
|
|
|
|
found_top3, rank3, loc3 = diagnosis_in_differential(target_diagnosis, report, top_n=3) |
|
|
scores["top3_accuracy"] = 1.0 if found_top3 else 0.0 |
|
|
|
|
|
|
|
|
found_top1, rank1, loc1 = diagnosis_in_differential(target_diagnosis, report, top_n=1) |
|
|
scores["top1_accuracy"] = 1.0 if found_top1 else 0.0 |
|
|
|
|
|
|
|
|
scores["parse_success"] = 1.0 |
|
|
|
|
|
|
|
|
scores["has_recommendations"] = 1.0 if len(report.suggested_next_steps) > 0 else 0.0 |
|
|
|
|
|
details = { |
|
|
"target_diagnosis": target_diagnosis, |
|
|
"top_diagnosis": report.differential_diagnosis[0].diagnosis if report.differential_diagnosis else "NONE", |
|
|
"num_diagnoses": len(report.differential_diagnosis), |
|
|
"found_at_rank": rank_any if found_any else -1, |
|
|
"all_diagnoses": [d.diagnosis for d in report.differential_diagnosis[:5]], |
|
|
} |
|
|
|
|
|
icon = "β" if found_any else "β" |
|
|
top_dx = report.differential_diagnosis[0].diagnosis if report.differential_diagnosis else "NONE" |
|
|
print(f"{icon} top1={'Y' if found_top1 else 'N'} diag={'Y' if found_any else 'N'} | top: {top_dx[:30]} ({elapsed_ms}ms)") |
|
|
else: |
|
|
scores = { |
|
|
"diagnostic_accuracy": 0.0, |
|
|
"top3_accuracy": 0.0, |
|
|
"top1_accuracy": 0.0, |
|
|
"parse_success": 0.0, |
|
|
"has_recommendations": 0.0, |
|
|
} |
|
|
details = {"target_diagnosis": target_diagnosis, "error": error} |
|
|
print(f"β FAILED: {error[:80] if error else 'unknown'}") |
|
|
|
|
|
result = ValidationResult( |
|
|
case_id=case.case_id, |
|
|
source_dataset="pmc", |
|
|
success=report is not None, |
|
|
scores=scores, |
|
|
pipeline_time_ms=elapsed_ms, |
|
|
step_results=step_results, |
|
|
report_summary=report.patient_summary[:200] if report else None, |
|
|
error=error, |
|
|
details=details, |
|
|
) |
|
|
results.append(result) |
|
|
save_incremental(result, "pmc") |
|
|
|
|
|
if i < len(cases) - 1: |
|
|
await asyncio.sleep(delay_between_cases) |
|
|
|
|
|
|
|
|
total = len(results) |
|
|
successful = sum(1 for r in results if r.success) |
|
|
|
|
|
metric_names = ["diagnostic_accuracy", "top3_accuracy", "top1_accuracy", "parse_success", "has_recommendations"] |
|
|
metrics = {} |
|
|
for m in metric_names: |
|
|
values = [r.scores.get(m, 0.0) for r in results] |
|
|
metrics[m] = sum(values) / len(values) if values else 0.0 |
|
|
|
|
|
times = [r.pipeline_time_ms for r in results if r.success] |
|
|
metrics["avg_pipeline_time_ms"] = sum(times) / len(times) if times else 0 |
|
|
|
|
|
summary = ValidationSummary( |
|
|
dataset="pmc", |
|
|
total_cases=total, |
|
|
successful_cases=successful, |
|
|
failed_cases=total - successful, |
|
|
metrics=metrics, |
|
|
per_case=results, |
|
|
run_duration_sec=time.time() - start_time, |
|
|
) |
|
|
|
|
|
return summary |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def main(): |
|
|
"""Run PMC Case Reports validation standalone.""" |
|
|
import argparse |
|
|
|
|
|
parser = argparse.ArgumentParser(description="PMC Case Reports Validation") |
|
|
parser.add_argument("--max-cases", type=int, default=10, help="Number of cases to evaluate") |
|
|
parser.add_argument("--seed", type=int, default=42, help="Random seed") |
|
|
parser.add_argument("--no-drugs", action="store_true", help="Skip drug interaction check") |
|
|
parser.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval") |
|
|
parser.add_argument("--delay", type=float, default=2.0, help="Delay between cases (seconds)") |
|
|
args = parser.parse_args() |
|
|
|
|
|
print("PMC Case Reports Validation Harness") |
|
|
print("=" * 40) |
|
|
|
|
|
cases = await fetch_pmc_cases(max_cases=args.max_cases, seed=args.seed) |
|
|
summary = await validate_pmc( |
|
|
cases, |
|
|
include_drug_check=not args.no_drugs, |
|
|
include_guidelines=not args.no_guidelines, |
|
|
delay_between_cases=args.delay, |
|
|
) |
|
|
|
|
|
print_summary(summary) |
|
|
path = save_results(summary) |
|
|
print(f"Results saved to: {path}") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
asyncio.run(main()) |
|
|
|