Spaces:
Sleeping
Sleeping
"""
prepare_data.py — run LOCALLY once to generate data/cases.json.

Downloads:
    zou-lab/MedCaseReasoning (train + val splits)
    sauravlmx/MEDEC-MS (train split)

Outputs:
    data/cases.json — 15 curated cases (COMMIT THIS)
    data/icd10_synonyms.json — diagnosis normalisation map (COMMIT THIS)

Each case pairs one MedCaseReasoning row, selected by matching its
final_diagnosis against per-task keyword lists (3 easy, 5 medium, 7 hard),
with one MEDEC-MS clinical note that may or may not contain an annotated
error. The MEDEC note pools are shuffled after seeding `random` with SEED,
so repeated runs pick the same notes.

Usage:
    pip install datasets
    python prepare_data.py
"""
| from __future__ import annotations | |
| import json | |
| import random | |
| import re | |
| from pathlib import Path | |
| from datasets import load_dataset | |
# Fixed seed so the MEDEC note pools are shuffled identically on every run.
SEED = 42

# Generated JSON lives in a ./data directory next to this script.
OUTPUT_DIR = Path(__file__).parent.joinpath("data")
OUTPUT_DIR.mkdir(exist_ok=True)  # side effect: runs when the module is imported

random.seed(SEED)
# ---------------------------------------------------------------------------
# ICD-10 synonym map — extended
# ---------------------------------------------------------------------------
# Maps lower-cased diagnosis strings to ICD-10-CM codes (7th-character
# extensions such as "XA" indicate the CM variant). Strings that share a code
# are treated as synonyms of one another when cases are built, so two terms
# with different clinical meaning must never share a code.
ICD10_SYNONYMS: dict[str, str] = {
    # Pneumonia
    "community-acquired pneumonia": "J18.9",
    "community acquired pneumonia": "J18.9",
    "cap": "J18.9",
    "pneumonia": "J18.9",
    "lobar pneumonia": "J18.1",
    "bacterial pneumonia": "J15.9",
    "streptococcal pneumonia": "J13",
    "pneumococcal pneumonia": "J13",
    # UTI / sepsis
    "urinary tract infection": "N39.0",
    "uti": "N39.0",
    # NOTE(review): "urosepsis" is not itself a codable ICD-10-CM term; A41.51
    # (sepsis due to E. coli) assumes the usual organism — confirm intent.
    "urosepsis": "A41.51",
    "sepsis due to uti": "A41.51",
    "gram-negative sepsis": "A41.50",
    "septicemia": "A41.9",
    "sepsis": "A41.9",
    # NSTEMI
    "nstemi": "I21.4",
    "non-st elevation myocardial infarction": "I21.4",
    "non st elevation mi": "I21.4",
    "acute coronary syndrome": "I24.9",
    "acs": "I24.9",
    "unstable angina": "I20.0",
    # PE
    "pulmonary embolism": "I26.99",
    "pe": "I26.99",
    "acute pe": "I26.99",
    "pulmonary thromboembolism": "I26.99",
    # AKI
    "acute kidney injury": "N17.9",
    "aki": "N17.9",
    "acute renal failure": "N17.9",
    # Prerenal AKI has no dedicated N17 subcode, so it sits under N17.9. It
    # must NOT share N17.0, which is AKI *with tubular necrosis* (ATN) — the
    # clinically opposite picture.
    # NOTE(review): some coders use R39.2 (prerenal uremia) for "prerenal
    # azotemia"; kept with prerenal AKI so the three remain synonyms.
    "pre-renal aki": "N17.9",
    "prerenal aki": "N17.9",
    "prerenal azotemia": "N17.9",
    # N17.0 = acute kidney failure with tubular necrosis
    # (N17.1 would be acute cortical necrosis).
    "atn": "N17.0",
    "acute tubular necrosis": "N17.0",
    # Hyponatremia
    "hyponatremia": "E87.1",
    "siadh": "E22.2",
    "syndrome of inappropriate antidiuretic hormone": "E22.2",
    # DKA
    "diabetic ketoacidosis": "E11.10",
    "dka": "E11.10",
    "euglycemic dka": "E13.10",
    "euglycemic diabetic ketoacidosis": "E13.10",
    "edka": "E13.10",
    # CO poisoning — T58.9x is "unspecified source"; T58.0x would assert
    # motor-vehicle exhaust specifically.
    "carbon monoxide poisoning": "T58.91XA",
    "co poisoning": "T58.91XA",
    "carbon monoxide toxicity": "T58.91XA",
    # HIT
    "heparin-induced thrombocytopenia": "T45.515A",
    "hit": "T45.515A",
    "hit type 2": "T45.515A",
    "heparin induced thrombocytopenia": "T45.515A",
    # Aortic dissection
    "aortic dissection": "I71.00",
    "type a aortic dissection": "I71.01",
    "type b aortic dissection": "I71.03",
    # Heart failure
    "heart failure": "I50.9",
    "acute decompensated heart failure": "I50.9",
    "adhf": "I50.9",
    "congestive heart failure": "I50.9",
    "chf": "I50.9",
    # Hypertensive emergency
    "hypertensive emergency": "I16.1",
    "hypertensive urgency": "I16.0",
    "hypertensive crisis": "I16.9",
    # Addison's
    "adrenal insufficiency": "E27.1",
    "addisonian crisis": "E27.2",
    "addison's disease": "E27.1",
    "primary adrenal insufficiency": "E27.1",
    # Meningitis
    "bacterial meningitis": "G00.9",
    "meningitis": "G03.9",
    # A87.9 = viral meningitis, unspecified (G02 is meningitis in other
    # infectious diseases classified elsewhere).
    "viral meningitis": "A87.9",
    "cryptococcal meningitis": "B45.1",
    # Thyroid storm — ICD-10-CM has no E05.5; "with thyrotoxic crisis or
    # storm" is the ".x1" fifth character (E05.91 when the cause is unspecified).
    "thyroid storm": "E05.91",
    "thyrotoxic crisis": "E05.91",
    "thyrotoxicosis": "E05.90",
}
# ---------------------------------------------------------------------------
# MEDEC error extraction helpers
# ---------------------------------------------------------------------------
def parse_medec_sentences(sentences_text: str) -> list[tuple[str, str]]:
    """Split a MEDEC ``Sentences`` blob into ``(sentence_id, sentence_text)`` pairs.

    Each non-blank line is split at its first run of whitespace: the leading
    token becomes the id and the remainder the sentence. Lines consisting of
    a single token (no sentence text) are skipped.
    """
    pairs: list[tuple[str, str]] = []
    for raw in sentences_text.strip().splitlines():
        head_and_rest = raw.strip().split(maxsplit=1)
        if len(head_and_rest) != 2:
            continue
        sid, body = head_and_rest
        pairs.append((sid, body))
    return pairs
def build_soap_note(sentences_text: str) -> str:
    """Return the MEDEC ``Sentences`` text with surrounding whitespace trimmed.

    The text is otherwise passed through unchanged, so any per-line sentence
    id prefixes are preserved.
    """
    trimmed = sentences_text.strip()
    return trimmed
# ---------------------------------------------------------------------------
# MedCaseReasoning case templates
# ---------------------------------------------------------------------------
# We hand-select 15 representative cases by matching final_diagnosis keywords.
#
# Each *_TARGETS_ORDERED list holds one group of interchangeable substrings per
# desired diagnosis, most preferred group first. main() takes at most one case
# per group, then tops up with rows matching the flat *_TARGETS keywords, and
# finally with any unused row.

EASY_TARGETS_ORDERED = [
    ["community-acquired pneumonia", "community acquired pneumonia"],
    ["myocardial infarction", "nstemi", "stemi", "acute mi"],
    ["sepsis", "urosepsis", "bacteremia"],
]

MEDIUM_TARGETS_ORDERED = [
    ["pulmonary embolism"],
    ["acute kidney injury", "aki", "renal failure"],
    ["hyponatremia"],
    ["heart failure", "cardiac failure", "congestive"],
    ["hypertensive emergency", "hypertensive urgency", "hypertensive crisis"],
]

HARD_TARGETS_ORDERED = [
    ["euglycemic diabetic ketoacidosis", "euglycemic dka", "edka"],
    ["carbon monoxide"],
    ["heparin-induced thrombocytopenia", "hit type"],
    ["aortic dissection"],
    ["adrenal insufficiency", "addisonian crisis", "addison"],
    ["meningitis"],
    ["diabetic ketoacidosis", "dka"],  # fallback group for hard
]

# Flat keyword lists used by the _matches_any() fallback.
EASY_TARGETS = [
    "pneumonia",
    "myocardial infarction",
    "nstemi",
    "sepsis",
]
MEDIUM_TARGETS = [
    "pulmonary embolism",
    "acute kidney",
    "hyponatremia",
    "heart failure",
    "hypertensive",
]
HARD_TARGETS = [
    "diabetic ketoacidosis",
    "carbon monoxide",
    "heparin",
    "aortic dissection",
    "adrenal",
    "meningitis",
]
| def _matches_any(text: str, keywords: list[str]) -> bool: | |
| t = text.lower() | |
| return any(k in t for k in keywords) | |
# Word boundaries (\b) stop short abbreviations from matching inside ordinary
# words under IGNORECASE — e.g. without them "at 10 am" yields temperature=10.
_VITAL_PATTERNS = {
    key: re.compile(pattern, re.IGNORECASE)
    for key, pattern in {
        "temperature": r"\b(?:temp(?:erature)?|T)\b[:\s]+(\d+\.?\d*)\s*(?:°?[CF])?",
        "heart_rate": r"\b(?:HR|heart rate|pulse)\b[:\s]+(\d+)",
        "blood_pressure": r"\b(?:BP|blood pressure)\b[:\s]+(\d+/\d+)",
        "respiratory_rate": r"\b(?:RR|resp(?:iratory)? rate)\b[:\s]+(\d+)",
        "o2_sat": r"\b(?:SpO2|O2 sat(?:uration)?|sats?)\b[:\s]+(\d+)\s*%?",
    }.items()
}


def extract_vitals(case_prompt: str) -> dict:
    """Best-effort extraction of vitals from a free-text case prompt.

    For each vital, the first ``<label> <number>`` occurrence is captured and
    stored as the matched string (e.g. ``"38.5"``, ``"120/80"``).

    Returns:
        Dict keyed by ``temperature``, ``heart_rate``, ``blood_pressure``,
        ``respiratory_rate`` and/or ``o2_sat``. If nothing is found, the dict
        is ``{"note": "See HPI for vitals"}``.
    """
    vitals = {}
    for key, pattern in _VITAL_PATTERNS.items():
        match = pattern.search(case_prompt)
        if match:
            vitals[key] = match.group(1)
    if not vitals:
        vitals["note"] = "See HPI for vitals"
    return vitals
# Word boundaries (\b) keep one/two-letter symbols from matching inside words
# under IGNORECASE: without them "took 2" gives potassium=2 and "SpO2 92"
# gives po2=92.
_LAB_PATTERNS = {
    key: re.compile(pattern, re.IGNORECASE)
    for key, pattern in {
        "sodium": r"\b(?:Na|sodium)\b[:\s]+(\d+)",
        "potassium": r"\b(?:K|potassium)\b[:\s]+(\d+\.?\d*)",
        "bicarbonate": r"\b(?:HCO3|bicarb(?:onate)?)\b[:\s]+(\d+)",
        "chloride": r"\b(?:Cl|chloride)\b[:\s]+(\d+)",
        "creatinine": r"\b(?:Cr|creatinine)\b[:\s]+(\d+\.?\d*)",
        "glucose": r"\b(?:glucose|gluc)\b[:\s]+(\d+)",
        "bun": r"\bBUN\b[:\s]+(\d+)",
        "troponin": r"\btroponin\b[:\s]+(\d+\.?\d*)",
        "wbc": r"\bWBC\b[:\s]+(\d+\.?\d*)",
        "hemoglobin": r"\b(?:Hgb|Hb|hemoglobin)\b[:\s]+(\d+\.?\d*)",
        "platelets": r"\b(?:PLT|platelets)\b[:\s]+(\d+)",
        "ph": r"\bpH\b[:\s]+(\d+\.?\d*)",
        "pco2": r"\bpCO2\b[:\s]+(\d+)",
        "po2": r"\bpO2\b[:\s]+(\d+)",
        "lactate": r"\blactate\b[:\s]+(\d+\.?\d*)",
    }.items()
}


def extract_labs(case_prompt: str) -> dict:
    """Pull out lab values mentioned in a free-text case prompt.

    For each analyte, the first ``<label> <number>`` occurrence is captured
    and stored as the matched numeric string (no unit conversion).

    Returns:
        Dict mapping analyte keys (``sodium``, ``potassium``, ...) to strings;
        analytes that are not found are absent.
    """
    labs = {}
    for key, pattern in _LAB_PATTERNS.items():
        match = pattern.search(case_prompt)
        if match:
            labs[key] = match.group(1)
    return labs
# Leading \b (and trailing \b on the short TTE/TEE acronyms) prevents matches
# inside words under IGNORECASE, e.g. "cigarettes" -> "TTE" or
# "direct chest trauma" -> "CT chest".
_IMAGING_PATTERNS = {
    modality: re.compile(pattern, re.IGNORECASE)
    for modality, pattern in {
        "CXR": r"\b(?:CXR|chest x.ray|chest radiograph)[^.]*\.",
        "CT_CHEST": r"\bCT\s+(?:of\s+)?(?:the\s+)?chest[^.]*\.",
        "CT_HEAD": r"\bCT\s+(?:of\s+)?(?:the\s+)?head[^.]*\.",
        "EKG": r"\b(?:EKG|ECG|electrocardiogram)[^.]*\.",
        "ECHO": r"\b(?:echo(?:cardiogram)?|TTE\b|TEE\b)[^.]*\.",
    }.items()
}


def extract_imaging(case_prompt: str) -> dict:
    """Pull imaging/ECG findings mentioned in a free-text case prompt.

    For each modality, the first sentence fragment starting at the modality
    keyword and ending at the next period is captured verbatim (stripped).

    Returns:
        Dict keyed by ``CXR``, ``CT_CHEST``, ``CT_HEAD``, ``EKG`` and/or
        ``ECHO``; modalities not mentioned are absent.
    """
    imaging = {}
    for modality, pattern in _IMAGING_PATTERNS.items():
        match = pattern.search(case_prompt)
        if match:
            imaging[modality] = match.group(0).strip()
    return imaging
def required_calcs_for_diagnosis(diagnosis: str, labs: dict) -> list[dict]:
    """Return the calculation tasks appropriate for a diagnosis.

    The first matching diagnosis family wins, checked in this order:

    * DKA / ketoacidosis   -> anion gap and glucose-corrected sodium
    * AKI / kidney / renal -> fractional excretion of sodium (FENa)
    * pulmonary embolism   -> Wells PE score (hard-coded inputs)
    * NSTEMI / MI / ACS    -> TIMI score (hard-coded inputs)
    * sepsis / infection   -> qSOFA
    * anything else        -> anion gap

    Args:
        diagnosis: Free-text final diagnosis (matched case-insensitively).
        labs: Mapping of lab name to a number or numeric string, e.g. the
            output of ``extract_labs``. Missing values fall back to the
            defaults hard-coded below. The qSOFA branch also looks up
            ``"respiratory_rate"`` and ``"systolic_bp"`` (``"120/80"``-style
            strings are accepted) in this mapping.
            NOTE: ``urine_creatinine``/``urine_sodium`` are not produced by
            ``extract_labs``, so FENa uses their defaults unless a caller
            supplies them.

    Returns:
        List of dicts with keys ``formula``, ``inputs``, ``expected`` and
        ``tolerance_pct``.
    """
    dx = diagnosis.lower()
    calcs: list[dict] = []

    # Regexes (word boundaries / lookbehind) instead of bare substrings so short
    # keywords do not fire inside unrelated words: "adrenal" must not hit the
    # renal branch, "pemphigus" must not hit "pe", "Kawasaki" must not hit "aki".
    if re.search(r"\be?dka\b|ketoacidosis", dx):
        na = float(labs.get("sodium", 135))
        cl = float(labs.get("chloride", 98))
        bicarb = float(labs.get("bicarbonate", 24))
        calcs.append({
            "formula": "anion_gap",
            "inputs": {"na": na, "cl": cl, "bicarb": bicarb},
            "expected": na - cl - bicarb,
            "tolerance_pct": 5,
        })
        glucose = float(labs.get("glucose", 100))
        # +1.6 mEq/L Na per 100 above 100 (glucose presumably in mg/dL).
        calcs.append({
            "formula": "corrected_sodium",
            "inputs": {"na": na, "glucose": glucose},
            "expected": round(na + 0.016 * (glucose - 100), 1),
            "tolerance_pct": 5,
        })
    elif re.search(r"\baki\b|kidney|(?<!ad)(?<!supra)renal", dx):
        ucr = float(labs.get("urine_creatinine", 120))
        pcr = float(labs.get("creatinine", 1.5))
        una = float(labs.get("urine_sodium", 20))
        pna = float(labs.get("sodium", 138))
        calcs.append({
            "formula": "fena",
            "inputs": {"ucr": ucr, "pcr": pcr, "una": una, "pna": pna},
            "expected": round((una * pcr) / (pna * ucr) * 100, 2),
            "tolerance_pct": 10,
        })
    elif re.search(r"pulmonary embolism|\bpe\b", dx):
        # alt dx less likely (3) + HR > 100 (1.5) = 4.5
        calcs.append({
            "formula": "wells_pe",
            "inputs": {"dvt_symptoms": False, "alt_dx_less_likely": True, "hr_gt_100": True,
                       "immobilisation": False, "prior_dvt_pe": False, "hemoptysis": False, "malignancy": False},
            "expected": 4.5,
            "tolerance_pct": 20,
        })
    elif re.search(r"nstemi|myocardial infarction|\bacs\b", dx):
        # One point per True criterion: 6 of 7.
        calcs.append({
            "formula": "timi_nstemi",
            "inputs": {"age_gte_65": True, "gte_3_cad_risk_factors": True, "prior_stenosis_50": True,
                       "st_deviation": True, "gte_2_anginal_events_24h": True,
                       "aspirin_use_7d": False, "elevated_cardiac_markers": True},
            "expected": 6,
            "tolerance_pct": 0,
        })
    elif re.search(r"sepsis|infection", dx):
        rr = int(float(labs.get("respiratory_rate", 22)))
        sbp_raw = labs.get("systolic_bp", "120/80")
        try:
            # Accepts either a bare systolic value or a "SBP/DBP" string.
            sbp = int(str(sbp_raw).split("/")[0])
        except ValueError:
            sbp = 120
        calcs.append({
            "formula": "qsofa",
            "inputs": {"rr": rr, "sbp": sbp, "ams": False},
            "expected": int(rr >= 22) + int(sbp <= 100),
            "tolerance_pct": 0,
        })

    if not calcs:
        # Fallback: anion gap
        na = float(labs.get("sodium", 138))
        cl = float(labs.get("chloride", 102))
        bicarb = float(labs.get("bicarbonate", 24))
        calcs.append({
            "formula": "anion_gap",
            "inputs": {"na": na, "cl": cl, "bicarb": bicarb},
            "expected": na - cl - bicarb,
            "tolerance_pct": 5,
        })
    return calcs
# ---------------------------------------------------------------------------
# Unnecessary tests (generic list; per-case may add more)
# ---------------------------------------------------------------------------
# Chart sections flagged as low-yield; build_case() attaches this list to
# every case it builds.
GENERIC_UNNECESSARY = [
    "chart.labs.LIPIDS",
    "chart.labs.THYROID",
    "chart.labs.IRON",
]
# ---------------------------------------------------------------------------
# Main data generation
# ---------------------------------------------------------------------------
def build_case(
    case_id: str,
    task: str,
    mcr_row: dict,
    medec_row: dict,
) -> dict:
    """Build a single case JSON object from a MedCaseReasoning row and a MEDEC row.

    Args:
        case_id: Identifier stored under ``id`` (e.g. ``"easy-workup-1"``).
        task: Task name stored under ``task``.
        mcr_row: MedCaseReasoning row; reads ``case_prompt``,
            ``final_diagnosis`` and ``diagnostic_reasoning``.
        medec_row: MEDEC-MS row; reads ``Sentences``, ``Error Sentence ID``,
            ``Error Sentence``, ``Corrected Sentence`` and ``Error Type``.

    Returns:
        A JSON-serialisable dict describing the case.
    """
    # Dataset rows may hold None for missing text; normalise to "" so the
    # regex extractors and slicing below never see None.
    case_prompt: str = mcr_row.get("case_prompt") or ""
    diagnosis: str = mcr_row.get("final_diagnosis") or "Unknown"
    vitals = extract_vitals(case_prompt)
    labs = extract_labs(case_prompt)
    imaging = extract_imaging(case_prompt)

    # History sections
    history = {
        "hpi": case_prompt[:1200].strip(),  # cap for token efficiency
        "pmh": "See case prompt",
        "medications": "See case prompt",
        "allergies": "NKDA (unless specified)",
        "social_history": "See case prompt",
    }

    # The qSOFA calculation looks up "respiratory_rate" and "systolic_bp" in
    # the mapping it is given, but those values come from the vitals, not the
    # labs. Pass a merged copy so the score reflects this case's vitals; `labs`
    # itself stays untouched because it is emitted verbatim below.
    calc_inputs = dict(labs)
    if "respiratory_rate" in vitals:
        calc_inputs.setdefault("respiratory_rate", vitals["respiratory_rate"])
    if "blood_pressure" in vitals:
        calc_inputs.setdefault("systolic_bp", vitals["blood_pressure"])
    calcs = required_calcs_for_diagnosis(diagnosis, calc_inputs)

    # MEDEC note error
    sentences_text: str = medec_row.get("Sentences") or ""
    error_id: str = str(medec_row.get("Error Sentence ID", "-1"))
    error_sentence: str = medec_row.get("Error Sentence", "") or ""
    corrected_sentence: str = medec_row.get("Corrected Sentence", "") or ""
    error_type: str = medec_row.get("Error Type", "NA") or "NA"
    note_error = {
        "sentence_id": error_id,
        "error_sentence": error_sentence,
        "correction": corrected_sentence,
        "error_type": error_type,
    }

    # Diagnosis synonyms: every other map entry sharing this diagnosis's code.
    # Only exact (lower-cased, stripped) matches are found in the map.
    dx_lower = diagnosis.lower().strip()
    icd_code = ICD10_SYNONYMS.get(dx_lower)
    if icd_code is None:
        synonyms = []
        # NOTE(review): Z99.9 is a real ICD-10-CM code (dependence on an
        # enabling machine); confirm it is intended as the "unmapped" marker.
        icd = "Z99.9"
    else:
        synonyms = [k for k, v in ICD10_SYNONYMS.items() if v == icd_code and k != dx_lower]
        icd = icd_code

    chief_complaint = case_prompt.split(".")[0][:200].strip() if case_prompt else "Chief complaint unavailable"

    return {
        "id": case_id,
        "task": task,
        "chief_complaint": chief_complaint,
        "history": history,
        "vitals": vitals,
        "labs": labs,
        "imaging": imaging,
        "correct_diagnosis": diagnosis,
        "icd10": icd,
        "diagnosis_synonyms": synonyms,
        "required_calculations": calcs,
        "necessary_tests": ["chart.history", "chart.vitals", "chart.labs.BMP"],
        # Copy so cases never share (and could co-mutate) one list object.
        "unnecessary_tests": list(GENERIC_UNNECESSARY),
        "soap_note": build_soap_note(sentences_text),
        "note_error": note_error,
        "correct_plan": "See diagnostic_reasoning in dataset",
        "teaching_point": (mcr_row.get("diagnostic_reasoning") or "")[:500],
    }
def main() -> None:
    """Download both datasets, curate 15 cases and write the JSON outputs.

    Picks 3 easy, 5 medium and 7 hard MedCaseReasoning rows using the
    *_TARGETS keyword lists and pairs each with a MEDEC-MS note. Each case is
    built with build_case(), and the results are written to
    OUTPUT_DIR/cases.json and OUTPUT_DIR/icd10_synonyms.json (UTF-8).

    Raises:
        RuntimeError: if the MEDEC-MS train split does not contain both notes
            with an annotated error and notes without one.
    """
    print("Loading MedCaseReasoning …")
    mcr = load_dataset("zou-lab/MedCaseReasoning")
    all_mcr = list(mcr["train"]) + list(mcr["val"])
    # Lower-cased diagnoses computed once; None-safe.
    all_dx = [(row.get("final_diagnosis") or "").lower() for row in all_mcr]

    print("Loading MEDEC-MS …")
    medec = load_dataset("sauravlmx/MEDEC-MS")
    medec_train = list(medec["train"])

    # Filter MEDEC rows that have real errors (not "-1")
    medec_with_errors = [r for r in medec_train if str(r.get("Error Sentence ID", "-1")) != "-1"]
    medec_no_errors = [r for r in medec_train if str(r.get("Error Sentence ID", "-1")) == "-1"]
    if not medec_with_errors or not medec_no_errors:
        # Both pools are indexed modulo their length below; fail clearly
        # instead of with a ZeroDivisionError.
        raise RuntimeError(
            "MEDEC-MS train split must contain notes both with and without errors "
            f"(found {len(medec_with_errors)} with, {len(medec_no_errors)} without)"
        )
    # Shuffle order matters: it consumes the module-level seeded RNG.
    random.shuffle(medec_with_errors)
    random.shuffle(medec_no_errors)

    def pick_one(target_synonyms: list[str], used: set) -> dict | None:
        """Return the first unused row whose diagnosis contains a synonym.

        Synonyms are tried in order, so an earlier synonym always wins.
        Marks the returned row's index as used.
        """
        for synonym in target_synonyms:
            for idx, dx in enumerate(all_dx):
                if idx not in used and synonym in dx:
                    used.add(idx)
                    return all_mcr[idx]
        return None

    def pick_mcr_ordered(targets_ordered: list[list[str]], n: int, used: set, fallback_keywords: list[str]) -> list[dict]:
        """Pick up to n rows: one per target group, then keyword matches, then any unused row."""
        picked = []
        for synonyms in targets_ordered:
            if len(picked) >= n:
                break
            row = pick_one(synonyms, used)
            if row is not None:
                picked.append(row)
        # Fill remaining with fallback keyword search
        for idx, dx in enumerate(all_dx):
            if len(picked) >= n:
                break
            if idx not in used and _matches_any(dx, fallback_keywords):
                picked.append(all_mcr[idx])
                used.add(idx)
        # Last resort: any unused
        for idx, row in enumerate(all_mcr):
            if len(picked) >= n:
                break
            if idx not in used:
                picked.append(row)
                used.add(idx)
        return picked

    # Indices into all_mcr already assigned to a case (shared across tasks).
    used_ids: set = set()
    easy_rows = pick_mcr_ordered(EASY_TARGETS_ORDERED, 3, used_ids, EASY_TARGETS)
    medium_rows = pick_mcr_ordered(MEDIUM_TARGETS_ORDERED, 5, used_ids, MEDIUM_TARGETS)
    hard_rows = pick_mcr_ordered(HARD_TARGETS_ORDERED, 7, used_ids, HARD_TARGETS)
    task_rows = [
        ("easy-workup", easy_rows),
        ("medium-differential", medium_rows),
        ("hard-deceptive", hard_rows),
    ]

    medec_idx_err = 0
    medec_idx_ok = 0
    cases: list[dict] = []
    for task, rows in task_rows:
        for i, mcr_row in enumerate(rows):
            case_id = f"{task}-{i+1}"
            # Only easy cases at odd positions get an error-free note; every
            # other case (all medium and hard ones) gets a note with an error.
            if task == "easy-workup" and i % 2 == 1:
                medec_row = medec_no_errors[medec_idx_ok % len(medec_no_errors)]
                medec_idx_ok += 1
            else:
                medec_row = medec_with_errors[medec_idx_err % len(medec_with_errors)]
                medec_idx_err += 1
            case = build_case(case_id, task, mcr_row, medec_row)
            cases.append(case)
            print(f" ✓ {case_id}: {case['correct_diagnosis'][:60]}")

    # ensure_ascii=False emits raw non-ASCII characters, so write UTF-8
    # explicitly instead of relying on the platform's default encoding.
    out_cases = OUTPUT_DIR / "cases.json"
    out_cases.write_text(json.dumps(cases, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"\nWrote {len(cases)} cases → {out_cases}")

    out_icd = OUTPUT_DIR / "icd10_synonyms.json"
    out_icd.write_text(json.dumps(ICD10_SYNONYMS, indent=2, ensure_ascii=False), encoding="utf-8")
    print(f"Wrote ICD-10 synonym map → {out_icd}")


if __name__ == "__main__":
    main()