""" prepare_data.py — run LOCALLY once to generate data/cases.json. Downloads: zou-lab/MedCaseReasoning (train + val splits) sauravlmx/MEDEC-MS (train split) Outputs: data/cases.json — 15 curated cases (COMMIT THIS) data/icd10_synonyms.json — diagnosis normalisation map (COMMIT THIS) Usage: pip install datasets python prepare_data.py """ from __future__ import annotations import json import random import re from pathlib import Path from datasets import load_dataset SEED = 42 OUTPUT_DIR = Path(__file__).parent / "data" OUTPUT_DIR.mkdir(exist_ok=True) random.seed(SEED) # --------------------------------------------------------------------------- # ICD-10 synonym map — extended # --------------------------------------------------------------------------- ICD10_SYNONYMS: dict[str, str] = { # Pneumonia "community-acquired pneumonia": "J18.9", "community acquired pneumonia": "J18.9", "cap": "J18.9", "pneumonia": "J18.9", "lobar pneumonia": "J18.1", "bacterial pneumonia": "J15.9", "streptococcal pneumonia": "J13", "pneumococcal pneumonia": "J13", # UTI / sepsis "urinary tract infection": "N39.0", "uti": "N39.0", "urosepsis": "A41.51", "sepsis due to uti": "A41.51", "gram-negative sepsis": "A41.50", "septicemia": "A41.9", "sepsis": "A41.9", # NSTEMI "nstemi": "I21.4", "non-st elevation myocardial infarction": "I21.4", "non st elevation mi": "I21.4", "acute coronary syndrome": "I24.9", "acs": "I24.9", "unstable angina": "I20.0", # PE "pulmonary embolism": "I26.99", "pe": "I26.99", "acute pe": "I26.99", "pulmonary thromboembolism": "I26.99", # AKI "acute kidney injury": "N17.9", "aki": "N17.9", "acute renal failure": "N17.9", "pre-renal aki": "N17.0", "prerenal aki": "N17.0", "prerenal azotemia": "N17.0", "atn": "N17.1", "acute tubular necrosis": "N17.1", # Hyponatremia "hyponatremia": "E87.1", "siadh": "E22.2", "syndrome of inappropriate antidiuretic hormone": "E22.2", # DKA "diabetic ketoacidosis": "E11.10", "dka": "E11.10", "euglycemic dka": "E13.10", "euglycemic diabetic ketoacidosis": "E13.10", "edka": "E13.10", # CO poisoning "carbon monoxide poisoning": "T58.01XA", "co poisoning": "T58.01XA", "carbon monoxide toxicity": "T58.01XA", # HIT "heparin-induced thrombocytopenia": "T45.515A", "hit": "T45.515A", "hit type 2": "T45.515A", "heparin induced thrombocytopenia": "T45.515A", # Aortic dissection "aortic dissection": "I71.00", "type a aortic dissection": "I71.01", "type b aortic dissection": "I71.03", # Heart failure "heart failure": "I50.9", "acute decompensated heart failure": "I50.9", "adhf": "I50.9", "congestive heart failure": "I50.9", "chf": "I50.9", # Hypertensive emergency "hypertensive emergency": "I16.1", "hypertensive urgency": "I16.0", "hypertensive crisis": "I16.9", # Addison's "adrenal insufficiency": "E27.1", "addisonian crisis": "E27.2", "addison's disease": "E27.1", "primary adrenal insufficiency": "E27.1", # Meningitis "bacterial meningitis": "G00.9", "meningitis": "G03.9", "viral meningitis": "G02", "cryptococcal meningitis": "B45.1", # Thyroid storm "thyroid storm": "E05.51", "thyrotoxic crisis": "E05.51", "thyrotoxicosis": "E05.90", } # --------------------------------------------------------------------------- # MEDEC error extraction helpers # --------------------------------------------------------------------------- def parse_medec_sentences(sentences_text: str) -> list[tuple[str, str]]: """Returns list of (sentence_id, sentence_text).""" result = [] for line in sentences_text.strip().splitlines(): line = line.strip() if not line: continue parts = line.split(maxsplit=1) if len(parts) == 2: result.append((parts[0], parts[1])) return result def build_soap_note(sentences_text: str) -> str: """Return the sentences as a clean numbered SOAP note.""" return sentences_text.strip() # --------------------------------------------------------------------------- # MedCaseReasoning case templates # --------------------------------------------------------------------------- # We hand-select 15 representative cases by matching final_diagnosis keywords. # Priority targets per task: # Ordered lists of specific target diagnoses (most preferred first) EASY_TARGETS_ORDERED = [ ["community-acquired pneumonia", "community acquired pneumonia"], ["myocardial infarction", "nstemi", "stemi", "acute mi"], ["sepsis", "urosepsis", "bacteremia"], ] MEDIUM_TARGETS_ORDERED = [ ["pulmonary embolism"], ["acute kidney injury", "aki", "renal failure"], ["hyponatremia"], ["heart failure", "cardiac failure", "congestive"], ["hypertensive emergency", "hypertensive urgency", "hypertensive crisis"], ] HARD_TARGETS_ORDERED = [ ["euglycemic diabetic ketoacidosis", "euglycemic dka", "edka"], ["carbon monoxide"], ["heparin-induced thrombocytopenia", "hit type"], ["aortic dissection"], ["adrenal insufficiency", "addisonian crisis", "addison"], ["meningitis"], ["diabetic ketoacidosis", "dka"], # fallback for hard ] # Flat versions for _matches_any fallback EASY_TARGETS = ["pneumonia", "myocardial infarction", "nstemi", "sepsis"] MEDIUM_TARGETS = ["pulmonary embolism", "acute kidney", "hyponatremia", "heart failure", "hypertensive"] HARD_TARGETS = ["diabetic ketoacidosis", "carbon monoxide", "heparin", "aortic dissection", "adrenal", "meningitis"] def _matches_any(text: str, keywords: list[str]) -> bool: t = text.lower() return any(k in t for k in keywords) def extract_vitals(case_prompt: str) -> dict: """Best-effort extraction of vitals from free-text case prompt.""" vitals = {} patterns = { "temperature": r"(?:temp(?:erature)?|T)[:\s]+(\d+\.?\d*)\s*(?:°?[CF])?", "heart_rate": r"(?:HR|heart rate|pulse)[:\s]+(\d+)", "blood_pressure": r"(?:BP|blood pressure)[:\s]+(\d+/\d+)", "respiratory_rate": r"(?:RR|resp(?:iratory)? rate)[:\s]+(\d+)", "o2_sat": r"(?:SpO2|O2 sat(?:uration)?|sats?)[:\s]+(\d+)\s*%?", } for key, pattern in patterns.items(): m = re.search(pattern, case_prompt, re.IGNORECASE) if m: vitals[key] = m.group(1) if not vitals: vitals["note"] = "See HPI for vitals" return vitals def extract_labs(case_prompt: str) -> dict: """Pull out any lab values mentioned in the prompt.""" labs = {} patterns = { "sodium": r"(?:Na|sodium)[:\s]+(\d+)", "potassium": r"(?:K|potassium)[:\s]+(\d+\.?\d*)", "bicarbonate": r"(?:HCO3|bicarb(?:onate)?)[:\s]+(\d+)", "chloride": r"(?:Cl|chloride)[:\s]+(\d+)", "creatinine": r"(?:Cr|creatinine)[:\s]+(\d+\.?\d*)", "glucose": r"(?:glucose|gluc)[:\s]+(\d+)", "bun": r"BUN[:\s]+(\d+)", "troponin": r"troponin[:\s]+(\d+\.?\d*)", "wbc": r"WBC[:\s]+(\d+\.?\d*)", "hemoglobin": r"(?:Hgb|Hb|hemoglobin)[:\s]+(\d+\.?\d*)", "platelets": r"(?:PLT|platelets)[:\s]+(\d+)", "ph": r"\bpH[:\s]+(\d+\.?\d*)", "pco2": r"pCO2[:\s]+(\d+)", "po2": r"pO2[:\s]+(\d+)", "lactate": r"lactate[:\s]+(\d+\.?\d*)", } for key, pattern in patterns.items(): m = re.search(pattern, case_prompt, re.IGNORECASE) if m: labs[key] = m.group(1) return labs def extract_imaging(case_prompt: str) -> dict: """Pull imaging findings mentioned in the prompt.""" imaging = {} modalities = { "CXR": r"(?:CXR|chest x.ray|chest radiograph)[^.]*\.", "CT_CHEST": r"CT\s+(?:of\s+)?(?:the\s+)?chest[^.]*\.", "CT_HEAD": r"CT\s+(?:of\s+)?(?:the\s+)?head[^.]*\.", "EKG": r"(?:EKG|ECG|electrocardiogram)[^.]*\.", "ECHO": r"(?:echo(?:cardiogram)?|TTE|TEE)[^.]*\.", } for modality, pattern in modalities.items(): m = re.search(pattern, case_prompt, re.IGNORECASE) if m: imaging[modality] = m.group(0).strip() return imaging def required_calcs_for_diagnosis(diagnosis: str, labs: dict) -> list[dict]: """Return a list of calculation tasks appropriate for this diagnosis.""" dx = diagnosis.lower() calcs = [] if any(k in dx for k in ["dka", "ketoacidosis"]): na = float(labs.get("sodium", 135)) cl = float(labs.get("chloride", 98)) bicarb = float(labs.get("bicarbonate", 24)) calcs.append({ "formula": "anion_gap", "inputs": {"na": na, "cl": cl, "bicarb": bicarb}, "expected": na - cl - bicarb, "tolerance_pct": 5, }) glucose = float(labs.get("glucose", 100)) calcs.append({ "formula": "corrected_sodium", "inputs": {"na": na, "glucose": glucose}, "expected": round(na + 0.016 * (glucose - 100), 1), "tolerance_pct": 5, }) elif any(k in dx for k in ["aki", "kidney", "renal"]): ucr = float(labs.get("urine_creatinine", 120)) pcr = float(labs.get("creatinine", 1.5)) una = float(labs.get("urine_sodium", 20)) pna = float(labs.get("sodium", 138)) calcs.append({ "formula": "fena", "inputs": {"ucr": ucr, "pcr": pcr, "una": una, "pna": pna}, "expected": round((una * pcr) / (pna * ucr) * 100, 2), "tolerance_pct": 10, }) elif any(k in dx for k in ["pulmonary embolism", " pe"]): calcs.append({ "formula": "wells_pe", "inputs": {"dvt_symptoms": False, "alt_dx_less_likely": True, "hr_gt_100": True, "immobilisation": False, "prior_dvt_pe": False, "hemoptysis": False, "malignancy": False}, "expected": 4.5, "tolerance_pct": 20, }) elif any(k in dx for k in ["nstemi", "myocardial infarction", "acs"]): calcs.append({ "formula": "timi_nstemi", "inputs": {"age_gte_65": True, "gte_3_cad_risk_factors": True, "prior_stenosis_50": True, "st_deviation": True, "gte_2_anginal_events_24h": True, "aspirin_use_7d": False, "elevated_cardiac_markers": True}, "expected": 6, "tolerance_pct": 0, }) elif any(k in dx for k in ["sepsis", "infection"]): rr = int(labs.get("respiratory_rate", 22)) sbp_raw = labs.get("systolic_bp", "120/80") try: sbp = int(str(sbp_raw).split("/")[0]) except Exception: sbp = 120 calcs.append({ "formula": "qsofa", "inputs": {"rr": rr, "sbp": sbp, "ams": False}, "expected": int(rr >= 22) + int(sbp <= 100), "tolerance_pct": 0, }) if not calcs: # Fallback: anion gap na = float(labs.get("sodium", 138)) cl = float(labs.get("chloride", 102)) bicarb = float(labs.get("bicarbonate", 24)) calcs.append({ "formula": "anion_gap", "inputs": {"na": na, "cl": cl, "bicarb": bicarb}, "expected": na - cl - bicarb, "tolerance_pct": 5, }) return calcs # --------------------------------------------------------------------------- # Unnecessary tests (generic list; per-case may add more) # --------------------------------------------------------------------------- GENERIC_UNNECESSARY = ["chart.labs.LIPIDS", "chart.labs.THYROID", "chart.labs.IRON"] # --------------------------------------------------------------------------- # Main data generation # --------------------------------------------------------------------------- def build_case( case_id: str, task: str, mcr_row: dict, medec_row: dict, ) -> dict: """Build a single case JSON object from MedCaseReasoning + MEDEC rows.""" case_prompt: str = mcr_row.get("case_prompt", "") diagnosis: str = mcr_row.get("final_diagnosis", "Unknown") vitals = extract_vitals(case_prompt) labs = extract_labs(case_prompt) imaging = extract_imaging(case_prompt) # History sections history = { "hpi": case_prompt[:1200].strip(), # cap for token efficiency "pmh": "See case prompt", "medications": "See case prompt", "allergies": "NKDA (unless specified)", "social_history": "See case prompt", } calcs = required_calcs_for_diagnosis(diagnosis, labs) # MEDEC note error sentences_text: str = medec_row.get("Sentences", "") error_id: str = str(medec_row.get("Error Sentence ID", "-1")) error_sentence: str = medec_row.get("Error Sentence", "") or "" corrected_sentence: str = medec_row.get("Corrected Sentence", "") or "" error_type: str = medec_row.get("Error Type", "NA") or "NA" note_error = { "sentence_id": error_id, "error_sentence": error_sentence, "correction": corrected_sentence, "error_type": error_type, } # Diagnosis synonyms from ICD map dx_lower = diagnosis.lower().strip() synonyms = [k for k, v in ICD10_SYNONYMS.items() if k != dx_lower and v == ICD10_SYNONYMS.get(dx_lower, "__none__")] icd = ICD10_SYNONYMS.get(dx_lower, "Z99.9") chief_complaint = case_prompt.split(".")[0][:200].strip() if case_prompt else "Chief complaint unavailable" return { "id": case_id, "task": task, "chief_complaint": chief_complaint, "history": history, "vitals": vitals, "labs": labs, "imaging": imaging, "correct_diagnosis": diagnosis, "icd10": icd, "diagnosis_synonyms": synonyms, "required_calculations": calcs, "necessary_tests": ["chart.history", "chart.vitals", "chart.labs.BMP"], "unnecessary_tests": GENERIC_UNNECESSARY, "soap_note": build_soap_note(sentences_text), "note_error": note_error, "correct_plan": "See diagnostic_reasoning in dataset", "teaching_point": mcr_row.get("diagnostic_reasoning", "")[:500], } def main() -> None: print("Loading MedCaseReasoning …") mcr = load_dataset("zou-lab/MedCaseReasoning") mcr_train = list(mcr["train"]) mcr_val = list(mcr["val"]) all_mcr = mcr_train + mcr_val print("Loading MEDEC-MS …") medec = load_dataset("sauravlmx/MEDEC-MS") medec_train = list(medec["train"]) # Filter MEDEC rows that have real errors (not "-1") medec_with_errors = [r for r in medec_train if str(r.get("Error Sentence ID", "-1")) != "-1"] medec_no_errors = [r for r in medec_train if str(r.get("Error Sentence ID", "-1")) == "-1"] random.shuffle(medec_with_errors) random.shuffle(medec_no_errors) def pick_one(target_synonyms: list[str], used: set) -> dict | None: """Find best match for one of the synonym strings.""" for synonym in target_synonyms: for row in all_mcr: dx = row.get("final_diagnosis", "").lower() idx = id(row) if idx not in used and synonym in dx: used.add(idx) return row return None def pick_mcr_ordered(targets_ordered: list[list[str]], n: int, used: set, fallback_keywords: list[str]) -> list[dict]: picked = [] for synonyms in targets_ordered: if len(picked) >= n: break row = pick_one(synonyms, used) if row: picked.append(row) # Fill remaining with fallback keyword search if len(picked) < n: for row in all_mcr: if len(picked) >= n: break dx = row.get("final_diagnosis", "").lower() idx = id(row) if idx not in used and _matches_any(dx, fallback_keywords): picked.append(row) used.add(idx) # Last resort: any unused for row in all_mcr: if len(picked) >= n: break idx = id(row) if idx not in used: picked.append(row) used.add(idx) return picked used_ids: set = set() easy_rows = pick_mcr_ordered(EASY_TARGETS_ORDERED, 3, used_ids, EASY_TARGETS) medium_rows = pick_mcr_ordered(MEDIUM_TARGETS_ORDERED, 5, used_ids, MEDIUM_TARGETS) hard_rows = pick_mcr_ordered(HARD_TARGETS_ORDERED, 7, used_ids, HARD_TARGETS) task_rows = [ ("easy-workup", easy_rows), ("medium-differential", medium_rows), ("hard-deceptive", hard_rows), ] medec_idx_err = 0 medec_idx_ok = 0 cases: list[dict] = [] for task, rows in task_rows: for i, mcr_row in enumerate(rows): case_id = f"{task}-{i+1}" # Hard cases more likely to have note errors; easy: ~50% chance no error if task == "easy-workup" and i % 2 == 1: medec_row = medec_no_errors[medec_idx_ok % len(medec_no_errors)] medec_idx_ok += 1 else: medec_row = medec_with_errors[medec_idx_err % len(medec_with_errors)] medec_idx_err += 1 case = build_case(case_id, task, mcr_row, medec_row) cases.append(case) print(f" ✓ {case_id}: {case['correct_diagnosis'][:60]}") out_cases = OUTPUT_DIR / "cases.json" out_cases.write_text(json.dumps(cases, indent=2, ensure_ascii=False)) print(f"\nWrote {len(cases)} cases → {out_cases}") out_icd = OUTPUT_DIR / "icd10_synonyms.json" out_icd.write_text(json.dumps(ICD10_SYNONYMS, indent=2, ensure_ascii=False)) print(f"Wrote ICD-10 synonym map → {out_icd}") if __name__ == "__main__": main()