# meddiagnostic-env / prepare_data.py
# pratinavseth's picture
# Initial meddiagnostic-env: sequential clinical diagnostic environment
# 4fb9d7c verified
"""
prepare_data.py — run LOCALLY once to generate data/cases.json.
Downloads:
zou-lab/MedCaseReasoning (train + val splits)
sauravlmx/MEDEC-MS (train split)
Outputs:
data/cases.json — 15 curated cases (COMMIT THIS)
data/icd10_synonyms.json — diagnosis normalisation map (COMMIT THIS)
Usage:
pip install datasets
python prepare_data.py
"""
from __future__ import annotations
import json
import random
import re
from pathlib import Path
from datasets import load_dataset
# Fixed seed so the random.shuffle calls on the MEDEC pools in main() are
# reproducible from run to run.
SEED = 42
# Generated JSON files are written to a "data" directory next to this script.
OUTPUT_DIR = Path(__file__).parent / "data"
# NOTE: runs at import time, so merely importing this module creates the
# directory (its parent must already exist because parents=True is not set).
OUTPUT_DIR.mkdir(exist_ok=True)
random.seed(SEED)
# ---------------------------------------------------------------------------
# ICD-10 synonym map — extended
# ---------------------------------------------------------------------------
# Maps a lowercase diagnosis phrase (full name or abbreviation) to an ICD-10
# code string. build_case() looks up `diagnosis.lower().strip()` here and
# treats every other key sharing the same code as a synonym; main() dumps the
# whole map to data/icd10_synonyms.json.
# Keys must stay lowercase because lookups are done on lowercased text.
# NOTE(review): the codes are hand-entered; verify them against the ICD-10
# release the grader uses before relying on exact code equality.
ICD10_SYNONYMS: dict[str, str] = {
    # Pneumonia
    "community-acquired pneumonia": "J18.9",
    "community acquired pneumonia": "J18.9",
    "cap": "J18.9",
    "pneumonia": "J18.9",
    "lobar pneumonia": "J18.1",
    "bacterial pneumonia": "J15.9",
    "streptococcal pneumonia": "J13",
    "pneumococcal pneumonia": "J13",
    # UTI / sepsis
    "urinary tract infection": "N39.0",
    "uti": "N39.0",
    "urosepsis": "A41.51",
    "sepsis due to uti": "A41.51",
    "gram-negative sepsis": "A41.50",
    "septicemia": "A41.9",
    "sepsis": "A41.9",
    # NSTEMI
    "nstemi": "I21.4",
    "non-st elevation myocardial infarction": "I21.4",
    "non st elevation mi": "I21.4",
    "acute coronary syndrome": "I24.9",
    "acs": "I24.9",
    "unstable angina": "I20.0",
    # PE
    "pulmonary embolism": "I26.99",
    "pe": "I26.99",
    "acute pe": "I26.99",
    "pulmonary thromboembolism": "I26.99",
    # AKI
    "acute kidney injury": "N17.9",
    "aki": "N17.9",
    "acute renal failure": "N17.9",
    "pre-renal aki": "N17.0",
    "prerenal aki": "N17.0",
    "prerenal azotemia": "N17.0",
    "atn": "N17.1",
    "acute tubular necrosis": "N17.1",
    # Hyponatremia
    "hyponatremia": "E87.1",
    "siadh": "E22.2",
    "syndrome of inappropriate antidiuretic hormone": "E22.2",
    # DKA
    "diabetic ketoacidosis": "E11.10",
    "dka": "E11.10",
    "euglycemic dka": "E13.10",
    "euglycemic diabetic ketoacidosis": "E13.10",
    "edka": "E13.10",
    # CO poisoning
    "carbon monoxide poisoning": "T58.01XA",
    "co poisoning": "T58.01XA",
    "carbon monoxide toxicity": "T58.01XA",
    # HIT
    "heparin-induced thrombocytopenia": "T45.515A",
    "hit": "T45.515A",
    "hit type 2": "T45.515A",
    "heparin induced thrombocytopenia": "T45.515A",
    # Aortic dissection
    "aortic dissection": "I71.00",
    "type a aortic dissection": "I71.01",
    "type b aortic dissection": "I71.03",
    # Heart failure
    "heart failure": "I50.9",
    "acute decompensated heart failure": "I50.9",
    "adhf": "I50.9",
    "congestive heart failure": "I50.9",
    "chf": "I50.9",
    # Hypertensive emergency
    "hypertensive emergency": "I16.1",
    "hypertensive urgency": "I16.0",
    "hypertensive crisis": "I16.9",
    # Addison's
    "adrenal insufficiency": "E27.1",
    "addisonian crisis": "E27.2",
    "addison's disease": "E27.1",
    "primary adrenal insufficiency": "E27.1",
    # Meningitis
    "bacterial meningitis": "G00.9",
    "meningitis": "G03.9",
    "viral meningitis": "G02",
    "cryptococcal meningitis": "B45.1",
    # Thyroid storm
    "thyroid storm": "E05.51",
    "thyrotoxic crisis": "E05.51",
    "thyrotoxicosis": "E05.90",
}
# ---------------------------------------------------------------------------
# MEDEC error extraction helpers
# ---------------------------------------------------------------------------
def parse_medec_sentences(sentences_text: str) -> list[tuple[str, str]]:
    """Split a MEDEC sentence block into ``(sentence_id, sentence_text)`` pairs.

    Each non-blank line is split once on its first run of whitespace: the
    leading token becomes the id and the remainder the sentence. Lines that
    hold only a single token (no text after the id) are dropped.
    """
    pairs: list[tuple[str, str]] = []
    for raw_line in sentences_text.strip().splitlines():
        fields = raw_line.strip().split(maxsplit=1)
        # Blank lines yield 0 fields, id-only lines yield 1; both are skipped.
        if len(fields) != 2:
            continue
        sentence_id, sentence = fields
        pairs.append((sentence_id, sentence))
    return pairs
def build_soap_note(sentences_text: str) -> str:
    """Return the MEDEC sentence block with surrounding whitespace removed.

    The text is otherwise passed through unchanged; any sentence numbering
    present in the input is kept as-is.
    """
    note = sentences_text.strip()
    return note
# ---------------------------------------------------------------------------
# MedCaseReasoning case templates
# ---------------------------------------------------------------------------
# Cases are selected by substring-matching these keywords against the
# lowercased `final_diagnosis` of each MedCaseReasoning row (see main()).
# main() requests 3 easy, 5 medium and 7 hard cases, i.e. 15 in total.
#
# *_TARGETS_ORDERED: one inner list per wanted case, most preferred target
# first. Within an inner list the synonyms are tried in order and at most one
# row is taken for the whole inner list.
EASY_TARGETS_ORDERED = [
    ["community-acquired pneumonia", "community acquired pneumonia"],
    ["myocardial infarction", "nstemi", "stemi", "acute mi"],
    ["sepsis", "urosepsis", "bacteremia"],
]
MEDIUM_TARGETS_ORDERED = [
    ["pulmonary embolism"],
    ["acute kidney injury", "aki", "renal failure"],
    ["hyponatremia"],
    ["heart failure", "cardiac failure", "congestive"],
    ["hypertensive emergency", "hypertensive urgency", "hypertensive crisis"],
]
HARD_TARGETS_ORDERED = [
    ["euglycemic diabetic ketoacidosis", "euglycemic dka", "edka"],
    ["carbon monoxide"],
    ["heparin-induced thrombocytopenia", "hit type"],
    ["aortic dissection"],
    ["adrenal insufficiency", "addisonian crisis", "addison"],
    ["meningitis"],
    ["diabetic ketoacidosis", "dka"],  # fallback for hard
]
# Flat keyword lists passed to _matches_any() when the ordered targets above
# did not yield enough rows for a difficulty tier.
EASY_TARGETS = ["pneumonia", "myocardial infarction", "nstemi", "sepsis"]
MEDIUM_TARGETS = ["pulmonary embolism", "acute kidney", "hyponatremia", "heart failure", "hypertensive"]
HARD_TARGETS = ["diabetic ketoacidosis", "carbon monoxide", "heparin", "aortic dissection", "adrenal", "meningitis"]
def _matches_any(text: str, keywords: list[str]) -> bool:
t = text.lower()
return any(k in t for k in keywords)
def extract_vitals(case_prompt: str) -> dict:
    """Best-effort extraction of vital signs from a free-text case prompt.

    Each vital is searched for with a case-insensitive regex of the form
    ``<label><colon/space><number>``; the first hit per vital wins and the
    captured value is returned as a string (``"38.5"``, ``"90/60"``).

    Every label is anchored with ``\\b`` so that an abbreviation only matches
    as a stand-alone token. Without the anchor, the one-letter label ``T``
    matched the tail of ordinary words ("at 3 pm" -> temperature "3").

    Returns:
        A dict with any of the keys ``temperature``, ``heart_rate``,
        ``blood_pressure``, ``respiratory_rate``, ``o2_sat``. If nothing is
        found, ``{"note": "See HPI for vitals"}`` is returned instead.
    """
    patterns = {
        # The decimal part is optional but must contain digits, so a
        # sentence-ending period ("Temp 38.") is not captured into the value.
        "temperature": r"\b(?:temp(?:erature)?|T)[:\s]+(\d+(?:\.\d+)?)",
        "heart_rate": r"\b(?:HR|heart rate|pulse)[:\s]+(\d+)",
        "blood_pressure": r"\b(?:BP|blood pressure)[:\s]+(\d+/\d+)",
        "respiratory_rate": r"\b(?:RR|resp(?:iratory)? rate)[:\s]+(\d+)",
        "o2_sat": r"\b(?:SpO2|O2 sat(?:uration)?|sats?)[:\s]+(\d+)",
    }
    vitals = {}
    for key, pattern in patterns.items():
        m = re.search(pattern, case_prompt, re.IGNORECASE)
        if m:
            vitals[key] = m.group(1)
    if not vitals:
        # Placeholder so downstream consumers never see an empty vitals dict.
        vitals["note"] = "See HPI for vitals"
    return vitals
def extract_labs(case_prompt: str) -> dict:
    """Pull out lab values mentioned in a free-text case prompt.

    Each analyte is searched for with a case-insensitive regex of the form
    ``<label><colon/space><number>``; the first hit per analyte wins and the
    value is returned as a string.

    Labels are anchored with ``\\b`` so that short abbreviations match only
    as stand-alone tokens. Unanchored, ``Na`` matched "angina 2", ``K``
    matched "week 3" and ``pO2`` matched inside "SpO2: 92", silently
    producing bogus lab values.

    Returns:
        A dict mapping analyte name to the captured value; empty if nothing
        matched.
    """
    # Decimal-capable analytes use ``\d+(?:\.\d+)?`` so that a sentence-ending
    # period ("lactate 4.") is not swallowed into the value.
    patterns = {
        "sodium": r"\b(?:Na|sodium)[:\s]+(\d+)",
        "potassium": r"\b(?:K|potassium)[:\s]+(\d+(?:\.\d+)?)",
        "bicarbonate": r"\b(?:HCO3|bicarb(?:onate)?)[:\s]+(\d+)",
        "chloride": r"\b(?:Cl|chloride)[:\s]+(\d+)",
        # "sCr" (serum creatinine) is kept matchable despite the word anchor.
        "creatinine": r"\b(?:s?Cr|creatinine)[:\s]+(\d+(?:\.\d+)?)",
        "glucose": r"\b(?:glucose|gluc)[:\s]+(\d+)",
        "bun": r"\bBUN[:\s]+(\d+)",
        "troponin": r"\btroponin[:\s]+(\d+(?:\.\d+)?)",
        "wbc": r"\bWBC[:\s]+(\d+(?:\.\d+)?)",
        "hemoglobin": r"\b(?:Hgb|Hb|hemoglobin)[:\s]+(\d+(?:\.\d+)?)",
        "platelets": r"\b(?:PLT|platelets)[:\s]+(\d+)",
        "ph": r"\bpH[:\s]+(\d+(?:\.\d+)?)",
        "pco2": r"\bpCO2[:\s]+(\d+)",
        "po2": r"\bpO2[:\s]+(\d+)",
        "lactate": r"\blactate[:\s]+(\d+(?:\.\d+)?)",
    }
    labs = {}
    for key, pattern in patterns.items():
        m = re.search(pattern, case_prompt, re.IGNORECASE)
        if m:
            labs[key] = m.group(1)
    return labs
def extract_imaging(case_prompt: str) -> dict:
    """Pull imaging / ECG findings mentioned in a free-text case prompt.

    For each modality the first case-insensitive occurrence of its label is
    located and the text from the label up to and including the next period
    is returned as the finding. A mention that is not followed by a period is
    not captured.

    Labels are anchored with ``\\b``: unanchored, the echo abbreviations
    ``TTE``/``TEE`` matched inside everyday words such as "admitted" and
    "teeth" and produced fake ECHO findings.

    Returns:
        A dict keyed by ``CXR``, ``CT_CHEST``, ``CT_HEAD``, ``EKG``, ``ECHO``
        (only the modalities that were found).
    """
    modalities = {
        "CXR": r"\b(?:CXR|chest x.ray|chest radiograph)[^.]*\.",
        # Optional "HR" keeps "HRCT chest" matchable despite the word anchor.
        "CT_CHEST": r"\b(?:HR)?CT\s+(?:of\s+)?(?:the\s+)?chest[^.]*\.",
        "CT_HEAD": r"\b(?:HR)?CT\s+(?:of\s+)?(?:the\s+)?head[^.]*\.",
        "EKG": r"\b(?:EKG|ECG|electrocardiogram)[^.]*\.",
        # "echo" stays a prefix match (echocardiogram, echocardiography);
        # the three-letter abbreviations must be whole words.
        "ECHO": r"\b(?:echo|TTE\b|TEE\b)[^.]*\.",
    }
    imaging = {}
    for modality, pattern in modalities.items():
        m = re.search(pattern, case_prompt, re.IGNORECASE)
        if m:
            imaging[modality] = m.group(0).strip()
    return imaging
def required_calcs_for_diagnosis(diagnosis: str, labs: dict) -> list[dict]:
    """Return the calculation tasks appropriate for ``diagnosis``.

    The diagnosis text is lowercased and routed to the first matching family
    (DKA -> AKI -> PE -> ACS -> sepsis). If no family matches, a plain anion
    gap task is returned so every case has at least one calculation.

    Matching uses regexes rather than bare substrings so that short keywords
    do not fire inside unrelated words: "renal" no longer matches "adrenal",
    "aki" no longer matches "Kawasaki", and "pe" must be a whole word (so
    "acute pericarditis" is not treated as PE, while a bare "PE" now is).

    Args:
        diagnosis: Free-text final diagnosis.
        labs: Mapping of input name to value (strings or numbers). Missing
            inputs fall back to the defaults hard-coded in each branch.
            The sepsis branch reads ``respiratory_rate`` and ``systolic_bp``
            (an "SBP/DBP" string or a bare number) from this mapping.

    Returns:
        A non-empty list of dicts with keys ``formula``, ``inputs``,
        ``expected`` and ``tolerance_pct``.
    """
    dx = diagnosis.lower()

    def mentions(*patterns: str) -> bool:
        """True if any regex in ``patterns`` is found in the diagnosis."""
        return any(re.search(pattern, dx) for pattern in patterns)

    calcs: list[dict] = []
    if mentions(r"\be?dka\b", r"ketoacidosis"):
        na = float(labs.get("sodium", 135))
        cl = float(labs.get("chloride", 98))
        bicarb = float(labs.get("bicarbonate", 24))
        calcs.append({
            "formula": "anion_gap",
            "inputs": {"na": na, "cl": cl, "bicarb": bicarb},
            "expected": na - cl - bicarb,
            "tolerance_pct": 5,
        })
        glucose = float(labs.get("glucose", 100))
        calcs.append({
            "formula": "corrected_sodium",
            "inputs": {"na": na, "glucose": glucose},
            # +1.6 mEq/L sodium per 100 mg/dL glucose above 100.
            "expected": round(na + 0.016 * (glucose - 100), 1),
            "tolerance_pct": 5,
        })
    # "(?<!ad)renal" still accepts "prerenal"/"hepatorenal" but not "adrenal".
    elif mentions(r"\baki\b", r"kidney", r"(?<!ad)renal"):
        ucr = float(labs.get("urine_creatinine", 120))
        pcr = float(labs.get("creatinine", 1.5))
        una = float(labs.get("urine_sodium", 20))
        pna = float(labs.get("sodium", 138))
        calcs.append({
            "formula": "fena",
            "inputs": {"ucr": ucr, "pcr": pcr, "una": una, "pna": pna},
            "expected": round((una * pcr) / (pna * ucr) * 100, 2),
            "tolerance_pct": 10,
        })
    elif mentions(r"pulmonary (?:thrombo)?embol", r"\bpe\b"):
        # Fixed scenario: alternative dx less likely (3) + HR > 100 (1.5).
        calcs.append({
            "formula": "wells_pe",
            "inputs": {"dvt_symptoms": False, "alt_dx_less_likely": True, "hr_gt_100": True,
                       "immobilisation": False, "prior_dvt_pe": False, "hemoptysis": False, "malignancy": False},
            "expected": 4.5,
            "tolerance_pct": 20,
        })
    elif mentions(r"nstemi", r"myocardial infarction", r"\bacs\b"):
        # Fixed scenario: six of the seven TIMI items are positive.
        calcs.append({
            "formula": "timi_nstemi",
            "inputs": {"age_gte_65": True, "gte_3_cad_risk_factors": True, "prior_stenosis_50": True,
                       "st_deviation": True, "gte_2_anginal_events_24h": True,
                       "aspirin_use_7d": False, "elevated_cardiac_markers": True},
            "expected": 6,
            "tolerance_pct": 0,
        })
    elif mentions(r"sepsis", r"infection"):
        rr = int(labs.get("respiratory_rate", 22))
        sbp_raw = labs.get("systolic_bp", "120/80")
        try:
            # Accept either "SBP/DBP" or a bare systolic value.
            sbp = int(str(sbp_raw).split("/")[0])
        except ValueError:
            sbp = 120
        calcs.append({
            "formula": "qsofa",
            "inputs": {"rr": rr, "sbp": sbp, "ams": False},
            # Altered mental status is fixed to False, so the max here is 2.
            "expected": int(rr >= 22) + int(sbp <= 100),
            "tolerance_pct": 0,
        })
    if not calcs:
        # Fallback: anion gap
        na = float(labs.get("sodium", 138))
        cl = float(labs.get("chloride", 102))
        bicarb = float(labs.get("bicarbonate", 24))
        calcs.append({
            "formula": "anion_gap",
            "inputs": {"na": na, "cl": cl, "bicarb": bicarb},
            "expected": na - cl - bicarb,
            "tolerance_pct": 5,
        })
    return calcs
# ---------------------------------------------------------------------------
# Unnecessary tests (generic list; per-case may add more)
# ---------------------------------------------------------------------------
# Chart actions that build_case() records under "unnecessary_tests" for every
# case; the same three entries are used regardless of diagnosis.
GENERIC_UNNECESSARY = ["chart.labs.LIPIDS", "chart.labs.THYROID", "chart.labs.IRON"]
# ---------------------------------------------------------------------------
# Main data generation
# ---------------------------------------------------------------------------
def build_case(
    case_id: str,
    task: str,
    mcr_row: dict,
    medec_row: dict,
) -> dict:
    """Build a single case JSON object from MedCaseReasoning + MEDEC rows.

    Args:
        case_id: Identifier stored under ``"id"``.
        task: Difficulty/task label stored under ``"task"``.
        mcr_row: MedCaseReasoning row; ``case_prompt``, ``final_diagnosis``
            and ``diagnostic_reasoning`` are read from it.
        medec_row: MEDEC row supplying the SOAP note (``Sentences``) and the
            injected-error metadata (``Error Sentence ID`` etc.).

    Returns:
        A JSON-serialisable dict describing the case.

    Missing or ``None`` dataset fields are treated as empty strings so one
    incomplete row cannot abort the whole generation run.
    """
    case_prompt: str = mcr_row.get("case_prompt") or ""
    diagnosis: str = mcr_row.get("final_diagnosis") or "Unknown"
    vitals = extract_vitals(case_prompt)
    labs = extract_labs(case_prompt)
    imaging = extract_imaging(case_prompt)
    # History sections
    history = {
        "hpi": case_prompt[:1200].strip(),  # cap for token efficiency
        "pmh": "See case prompt",
        "medications": "See case prompt",
        "allergies": "NKDA (unless specified)",
        "social_history": "See case prompt",
    }
    # The qSOFA branch of required_calcs_for_diagnosis() looks up
    # "respiratory_rate" and "systolic_bp", which extract_labs() never
    # produces. Forward the extracted vitals under those names so the score
    # reflects the case instead of always falling back to the defaults.
    calc_inputs = dict(labs)
    if "respiratory_rate" in vitals:
        calc_inputs["respiratory_rate"] = vitals["respiratory_rate"]
    if "blood_pressure" in vitals:
        # Passed as "SBP/DBP"; the consumer keeps the part before the slash.
        calc_inputs["systolic_bp"] = vitals["blood_pressure"]
    calcs = required_calcs_for_diagnosis(diagnosis, calc_inputs)
    # MEDEC note error
    sentences_text: str = medec_row.get("Sentences") or ""
    error_id: str = str(medec_row.get("Error Sentence ID", "-1"))
    error_sentence: str = medec_row.get("Error Sentence", "") or ""
    corrected_sentence: str = medec_row.get("Corrected Sentence", "") or ""
    error_type: str = medec_row.get("Error Type", "NA") or "NA"
    note_error = {
        "sentence_id": error_id,
        "error_sentence": error_sentence,
        "correction": corrected_sentence,
        "error_type": error_type,
    }
    # Diagnosis synonyms from ICD map: every other phrase sharing the code.
    dx_lower = diagnosis.lower().strip()
    mapped_code = ICD10_SYNONYMS.get(dx_lower)
    if mapped_code is None:
        synonyms = []
        # Placeholder code used when the diagnosis is not in the map.
        icd = "Z99.9"
    else:
        synonyms = [k for k, v in ICD10_SYNONYMS.items() if v == mapped_code and k != dx_lower]
        icd = mapped_code
    # First sentence of the prompt, capped at 200 characters.
    chief_complaint = case_prompt.split(".")[0][:200].strip() if case_prompt else "Chief complaint unavailable"
    return {
        "id": case_id,
        "task": task,
        "chief_complaint": chief_complaint,
        "history": history,
        "vitals": vitals,
        "labs": labs,
        "imaging": imaging,
        "correct_diagnosis": diagnosis,
        "icd10": icd,
        "diagnosis_synonyms": synonyms,
        "required_calculations": calcs,
        "necessary_tests": ["chart.history", "chart.vitals", "chart.labs.BMP"],
        # Copy so cases do not alias the module-level list.
        "unnecessary_tests": list(GENERIC_UNNECESSARY),
        "soap_note": build_soap_note(sentences_text),
        "note_error": note_error,
        "correct_plan": "See diagnostic_reasoning in dataset",
        "teaching_point": (mcr_row.get("diagnostic_reasoning") or "")[:500],
    }
def main() -> None:
    """Download both datasets, pick 15 cases and write the JSON outputs.

    Steps:
      1. Load MedCaseReasoning (train + val) and MEDEC-MS (train).
      2. Split MEDEC rows into those with an injected error and those without
         (``Error Sentence ID`` of "-1" means no error) and shuffle both.
      3. Pick 3 easy / 5 medium / 7 hard MedCaseReasoning rows by diagnosis
         keyword, never reusing a row.
      4. Pair every picked row with a MEDEC note and build the case dict.
      5. Write ``cases.json`` and ``icd10_synonyms.json`` into OUTPUT_DIR.

    Raises:
        RuntimeError: if the MEDEC-MS train split contains no rows.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading MedCaseReasoning …")
    mcr = load_dataset("zou-lab/MedCaseReasoning")
    all_mcr = list(mcr["train"]) + list(mcr["val"])

    print("Loading MEDEC-MS …")
    medec = load_dataset("sauravlmx/MEDEC-MS")
    medec_train = list(medec["train"])
    if not medec_train:
        raise RuntimeError("MEDEC-MS train split is empty; cannot attach SOAP notes")
    # Filter MEDEC rows that have real errors (not "-1")
    medec_with_errors = [r for r in medec_train if str(r.get("Error Sentence ID", "-1")) != "-1"]
    medec_no_errors = [r for r in medec_train if str(r.get("Error Sentence ID", "-1")) == "-1"]
    random.shuffle(medec_with_errors)
    random.shuffle(medec_no_errors)
    # If one pool is empty, borrow from the other instead of dividing by zero
    # in the round-robin index below.
    error_pool = medec_with_errors or medec_no_errors
    clean_pool = medec_no_errors or medec_with_errors

    def diagnosis_of(row: dict) -> str:
        """Lowercased final diagnosis; tolerates a missing/None field."""
        return (row.get("final_diagnosis") or "").lower()

    def pick_one(target_synonyms: list[str], used: set) -> dict | None:
        """Find best match for one of the synonym strings."""
        for synonym in target_synonyms:
            for idx, row in enumerate(all_mcr):
                if idx not in used and synonym in diagnosis_of(row):
                    used.add(idx)
                    return row
        return None

    def pick_mcr_ordered(targets_ordered: list[list[str]], n: int, used: set, fallback_keywords: list[str]) -> list[dict]:
        """Pick up to ``n`` unused rows: ordered targets, then keywords, then any."""
        picked = []
        for synonyms in targets_ordered:
            if len(picked) >= n:
                break
            row = pick_one(synonyms, used)
            if row is not None:
                picked.append(row)
        # Fill remaining with fallback keyword search
        for idx, row in enumerate(all_mcr):
            if len(picked) >= n:
                break
            if idx not in used and _matches_any(diagnosis_of(row), fallback_keywords):
                picked.append(row)
                used.add(idx)
        # Last resort: any unused
        for idx, row in enumerate(all_mcr):
            if len(picked) >= n:
                break
            if idx not in used:
                picked.append(row)
                used.add(idx)
        return picked

    # Positions in all_mcr that are already taken, shared across all tiers.
    used_ids: set = set()
    easy_rows = pick_mcr_ordered(EASY_TARGETS_ORDERED, 3, used_ids, EASY_TARGETS)
    medium_rows = pick_mcr_ordered(MEDIUM_TARGETS_ORDERED, 5, used_ids, MEDIUM_TARGETS)
    hard_rows = pick_mcr_ordered(HARD_TARGETS_ORDERED, 7, used_ids, HARD_TARGETS)
    task_rows = [
        ("easy-workup", easy_rows),
        ("medium-differential", medium_rows),
        ("hard-deceptive", hard_rows),
    ]

    medec_idx_err = 0
    medec_idx_ok = 0
    cases: list[dict] = []
    for task, rows in task_rows:
        for i, mcr_row in enumerate(rows):
            case_id = f"{task}-{i+1}"
            # Every second easy case (odd index) gets an error-free note;
            # all other cases get a note containing an injected error.
            if task == "easy-workup" and i % 2 == 1:
                medec_row = clean_pool[medec_idx_ok % len(clean_pool)]
                medec_idx_ok += 1
            else:
                medec_row = error_pool[medec_idx_err % len(error_pool)]
                medec_idx_err += 1
            case = build_case(case_id, task, mcr_row, medec_row)
            cases.append(case)
            print(f" ✓ {case_id}: {case['correct_diagnosis'][:60]}")

    # ensure_ascii=False emits raw non-ASCII characters, so the file encoding
    # must be pinned; the platform default may not be able to encode them.
    out_cases = OUTPUT_DIR / "cases.json"
    out_cases.write_text(json.dumps(cases, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"\nWrote {len(cases)} cases → {out_cases}")
    out_icd = OUTPUT_DIR / "icd10_synonyms.json"
    out_icd.write_text(json.dumps(ICD10_SYNONYMS, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"Wrote ICD-10 synonym map → {out_icd}")


if __name__ == "__main__":
    main()