Spaces:
Sleeping
Sleeping
| """ | |
| NEJM Image Challenge Dataset Loader. | |
| Expects the cx0/nejm-image-challenge dataset structure: | |
| nejm/ | |
| βββ data.json (or nejm_data.json) | |
| β Each entry: {date, image_url, prompt (clinical vignette), | |
| β options [A..E], correct_answer, votes} | |
| βββ images/ (downloaded images, named by date YYYYMMDD.jpg) | |
| βββ parsed_vignettes.json (pre-parsed structured fields, optional) | |
| The clinical vignette is decomposed into 5 requestable text channels | |
| using LLM-based parsing (see scripts/parse_nejm_vignettes.py). | |
| """ | |
| import json | |
| import logging | |
| import random | |
| import re | |
| from pathlib import Path | |
| from .base import DatasetBase, MedicalCase, ChannelData | |
| from api_client import encode_image_to_base64 | |
| import config | |
| logger = logging.getLogger(__name__) | |
# ---- Vignette parsing schema ----
# The five text channels a vignette is decomposed into. Order matters: it is
# the presentation order used when building channels, and each name must match
# a key in the JSON schema embedded in VIGNETTE_PARSE_PROMPT below.
VIGNETTE_FIELDS = [
    "demographics",
    "chief_complaint",
    "medical_history",
    "exam_findings",
    "investigations",
]

# Prompt template for LLM-based vignette parsing. Filled via str.format()
# with `vignette=...`; the doubled braces ({{ ... }}) in the final JSON
# example escape literal braces so .format() leaves them intact. The
# trailing backslashes are in-string line continuations, not Python ones.
VIGNETTE_PARSE_PROMPT = """You are a medical data extraction system. Parse the following clinical \
vignette into exactly 5 structured fields. Extract ONLY information that is explicitly stated. \
If a field has no relevant information, write "Not mentioned."
FIELDS:
1. demographics: Patient age, sex, race/ethnicity if stated.
2. chief_complaint: The primary presenting symptom(s) and their duration.
3. medical_history: Past medical conditions, medications, surgical history, family history, social history (smoking, alcohol, etc.).
4. exam_findings: Physical examination findings, vital signs.
5. investigations: Laboratory results, imaging findings, test results (anything with numbers or test names).
CLINICAL VIGNETTE:
{vignette}
Respond in EXACTLY this JSON format (no markdown, no extra text):
{{"demographics": "...", "chief_complaint": "...", "medical_history": "...", "exam_findings": "...", "investigations": "..."}}"""
class NEJMDataset(DatasetBase):
    """Loader for the NEJM Image Challenge dataset.

    Builds one MedicalCase per challenge entry. The diagnostic image (if
    found on disk) becomes an image channel; the clinical vignette is
    decomposed into the five VIGNETTE_FIELDS text channels via, in order
    of preference: a cached parse file, an LLM call, or a rule-based
    heuristic fallback.
    """

    def __init__(
        self,
        data_dir: str | Path | None = None,
        split: str = "test",
        vlm_client=None,
        use_cached_parse: bool = True,
    ):
        """
        Args:
            data_dir: Dataset root directory. Defaults to
                config.DATASET_PATHS["nejm"].
            split: Split name, passed through to DatasetBase.
            vlm_client: Optional LLM client used to parse vignettes when no
                cached parse exists; must expose call_with_retry().
            use_cached_parse: If True, reuse parsed_vignettes.json when present.
        """
        super().__init__(data_dir or config.DATASET_PATHS["nejm"], split)
        self.vlm_client = vlm_client
        self.use_cached_parse = use_cached_parse
        # Written by _load_or_parse_vignettes after a successful LLM parse run.
        self._parsed_cache_path = self.data_dir / "parsed_vignettes.json"

    def get_name(self) -> str:
        """Return the canonical dataset identifier."""
        return "nejm"

    def load(self) -> list[MedicalCase]:
        """Load raw entries, attach parsed vignettes, and build cases.

        Returns:
            List of successfully-built MedicalCase objects (also stored on
            self.cases). Empty list if no data file was found.
        """
        logger.info(f"Loading NEJM dataset from {self.data_dir}")
        # ---- Load raw data ----
        raw_data = self._load_raw_data()
        if not raw_data:
            return []
        logger.info(f"Found {len(raw_data)} NEJM cases")
        # ---- Load or create parsed vignettes ----
        parsed = self._load_or_parse_vignettes(raw_data)
        # ---- Build cases ----
        self.cases = []
        for entry in raw_data:
            case_id = entry.get("date", entry.get("id", "unknown"))
            case = self._build_case(entry, parsed.get(case_id, {}))
            if case is not None:
                self.cases.append(case)
        logger.info(f"Loaded {len(self.cases)} NEJM cases")
        return self.cases

    def _load_raw_data(self) -> list[dict]:
        """Load the raw NEJM dataset JSON.

        Tries a list of conventional filenames first, then falls back to the
        first *.json file in the data directory. A top-level dict is assumed
        to be keyed by date and is flattened into a list of entries.
        """
        for name in ["data.json", "nejm_data.json", "nejm.json", "dataset.json"]:
            p = self.data_dir / name
            if p.exists():
                with open(p, encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, dict):
                    # Handle {date: entry} format
                    return [{"date": k, **v} if isinstance(v, dict) else v
                            for k, v in data.items()]
                return data
        # Try loading all JSON files
        jsons = list(self.data_dir.glob("*.json"))
        if jsons:
            with open(jsons[0], encoding="utf-8") as f:
                return json.load(f)
        logger.error(f"No data file found in {self.data_dir}")
        return []

    def _load_or_parse_vignettes(self, raw_data: list[dict]) -> dict:
        """Load cached parsed vignettes or parse them with LLM.

        Returns:
            Mapping of case_id -> {field: text} for the VIGNETTE_FIELDS.
        """
        # Try cache first
        if self.use_cached_parse and self._parsed_cache_path.exists():
            logger.info(f"Loading cached vignette parses from {self._parsed_cache_path}")
            with open(self._parsed_cache_path) as f:
                return json.load(f)
        # Parse with LLM if client is available
        if self.vlm_client is not None:
            logger.info("Parsing vignettes with LLM (this may take a while)...")
            parsed = {}
            for entry in raw_data:
                case_id = entry.get("date", entry.get("id", "unknown"))
                vignette = entry.get("question", entry.get("prompt", entry.get("vignette", "")))
                if vignette:
                    parsed[case_id] = self._parse_vignette_with_llm(vignette)
            # Cache results so subsequent runs skip the (slow, paid) LLM calls.
            with open(self._parsed_cache_path, "w") as f:
                json.dump(parsed, f, indent=2)
            logger.info(f"Cached {len(parsed)} parsed vignettes")
            return parsed
        # Fallback: rule-based parsing (not cached — it is cheap to redo).
        logger.info("No LLM client available. Using rule-based vignette parsing (less accurate).")
        parsed = {}
        for entry in raw_data:
            case_id = entry.get("date", entry.get("id", "unknown"))
            vignette = entry.get("question", entry.get("prompt", entry.get("vignette", "")))
            if vignette:
                parsed[case_id] = self._parse_vignette_rules(vignette)
        return parsed

    def _parse_vignette_with_llm(self, vignette: str) -> dict:
        """Parse a single vignette using the LLM API.

        Falls back to rule-based parsing on any error (API failure or
        unparseable JSON in the response).
        """
        prompt = VIGNETTE_PARSE_PROMPT.format(vignette=vignette)
        try:
            response = self.vlm_client.call_with_retry(
                system_prompt="You are a medical data extraction system. Respond only with valid JSON.",
                user_text=prompt,
                images=None,
                temperature=0.0,
                max_tokens=1024,
            )
            # Parse JSON from response
            text = response.text.strip()
            # Strip markdown code fences if present
            text = re.sub(r"^```(?:json)?\s*", "", text)
            text = re.sub(r"\s*```$", "", text)
            parsed = json.loads(text)
            # Validate expected fields; backfill anything the model omitted.
            for field in VIGNETTE_FIELDS:
                if field not in parsed:
                    parsed[field] = "Not mentioned."
            return parsed
        except Exception as e:
            logger.warning(f"LLM vignette parsing failed: {e}. Falling back to rules.")
            return self._parse_vignette_rules(vignette)

    def _parse_vignette_rules(self, vignette: str) -> dict:
        """
        Rule-based fallback for vignette parsing.
        Uses heuristic sentence classification.
        """
        result = {f: "" for f in VIGNETTE_FIELDS}
        sentences = re.split(r'(?<=[.!?])\s+', vignette)
        # Patterns for classification. (re.compile caches internally, so
        # rebuilding these per call is cheap.)
        demo_pattern = re.compile(
            r'\b(\d{1,3})[-\s]year[-\s]old\b|'
            r'\b(male|female|man|woman|boy|girl)\b',
            re.IGNORECASE,
        )
        complaint_pattern = re.compile(
            r'\bpresent(?:s|ed|ing)\b|\bcomplain(?:s|ed|ing)\b|\breport(?:s|ed|ing)\b|'
            r'\bseek(?:s|ing)\b|\badmitted\b',
            re.IGNORECASE,
        )
        history_pattern = re.compile(
            r'\bhistory\b|\bprevious(?:ly)?\b|\bmedication\b|\btaking\b|\bdiagnosed\b|'
            r'\bsmok(?:es|ing|er)\b|\balcohol\b|\bfamily\b|\bsurgery\b',
            re.IGNORECASE,
        )
        exam_pattern = re.compile(
            r'\bexamination\b|\bexam\b|\bpalpat(?:ion|ed)\b|\bauscult(?:ation|ed)\b|'
            r'\bvital\b|\bblood\s+pressure\b|\bheart\s+rate\b|\btemperature\b|'
            r'\bappears\b|\btender\b|\bswollen\b|\berythema\b',
            re.IGNORECASE,
        )
        invest_pattern = re.compile(
            r'\b(?:hemoglobin|WBC|platelet|creatinine|BUN|glucose|sodium|potassium)\b|'
            r'\b(?:CT|MRI|X[-\s]?ray|ultrasound|ECG|EKG|biopsy)\b|'
            r'\b\d+\.?\d*\s*(?:mg|g|mL|mmol|mEq|U|IU|mmHg|\/dL|\/L)\b|'
            r'\blaboratory\b|\blab(?:s)?\b|\btest\b|\blevel\b|\bfinding\b',
            re.IGNORECASE,
        )
        for sent in sentences:
            sent = sent.strip()
            if not sent:
                continue
            # Demographics: take the first demographic-looking sentence only.
            if demo_pattern.search(sent) and not result["demographics"]:
                result["demographics"] = sent
                continue
            # Check each pattern (a sentence can match multiple, take first).
            # Order is most-specific first: investigations beat exam, etc.
            matched = False
            for field, pattern in [
                ("investigations", invest_pattern),
                ("exam_findings", exam_pattern),
                ("medical_history", history_pattern),
                ("chief_complaint", complaint_pattern),
            ]:
                if pattern.search(sent):
                    if result[field]:
                        result[field] += " " + sent
                    else:
                        result[field] = sent
                    matched = True
                    break
            # Unmatched sentences go to chief_complaint as default
            if not matched:
                if result["chief_complaint"]:
                    result["chief_complaint"] += " " + sent
                else:
                    result["chief_complaint"] = sent
        # Replace empty fields
        for field in VIGNETTE_FIELDS:
            if not result[field].strip():
                result[field] = "Not mentioned."
        return result

    @staticmethod
    def _date_to_yyyymmdd(date_str: str) -> str | None:
        """Convert 'apr-01-2010' style date to '20100401' for image lookup.

        Returns None if no known format matches. NOTE (bug fix): this helper
        takes no `self` but is called as `self._date_to_yyyymmdd(...)` from
        _build_case; without @staticmethod that call raised
        TypeError("takes 1 positional argument but 2 were given").
        """
        from datetime import datetime
        for fmt in ("%b-%d-%Y", "%B-%d-%Y", "%Y-%m-%d", "%Y%m%d"):
            try:
                dt = datetime.strptime(date_str, fmt)
                return dt.strftime("%Y%m%d")
            except ValueError:
                continue
        return None

    def _build_case(self, entry: dict, parsed_vignette: dict) -> MedicalCase | None:
        """Convert a raw NEJM entry + parsed vignette into a MedicalCase.

        Returns None when no usable channels (image or text) can be built.
        """
        case_id = entry.get("date", entry.get("id", "unknown"))
        # ---- Find image ----
        img_b64 = None
        img_dir = self.data_dir / "images"
        # Build candidate filenames: original case_id + YYYYMMDD conversion
        name_candidates = [case_id]
        yyyymmdd = self._date_to_yyyymmdd(case_id)
        if yyyymmdd:
            name_candidates.append(yyyymmdd)
        if img_dir.exists():
            for name in name_candidates:
                for ext in [".jpg", ".jpeg", ".png"]:
                    p = img_dir / f"{name}{ext}"
                    if p.exists():
                        try:
                            img_b64 = encode_image_to_base64(p)
                        except Exception:
                            # Unreadable/corrupt file: fall through to next
                            # candidate name (best-effort, no crash).
                            pass
                        break
                if img_b64 is not None:
                    break
            if img_b64 is None:
                # Glob for any match
                for name in name_candidates:
                    matches = list(img_dir.glob(f"*{name}*"))
                    if matches:
                        try:
                            img_b64 = encode_image_to_base64(matches[0])
                        except Exception:
                            pass
                        break
        # ---- Build all available channels, then split by config ----
        all_channels = {}
        if img_b64 is not None:
            image_meta = config.get_channel_definition("nejm", "image")
            all_channels["image"] = ChannelData(
                name="image",
                channel_type="image",
                description="The primary diagnostic image",
                value=img_b64,
                cost=float(image_meta.get("cost", 0.0)),
                tier=image_meta.get("tier", "unknown"),
                always_given=bool(image_meta.get("always_given", False)),
            )
        field_descriptions = {
            "demographics": "Patient age, sex, and ethnicity if mentioned",
            "chief_complaint": "The presenting symptom(s) and their duration",
            "medical_history": "Past medical conditions, medications, family and social history",
            "exam_findings": "Physical examination results and observations",
            "investigations": "Laboratory values, prior imaging results, and test outcomes",
        }
        for field in VIGNETTE_FIELDS:
            value = parsed_vignette.get(field, "Not mentioned.")
            field_meta = config.get_channel_definition("nejm", field)
            if value and value.strip() != "Not mentioned.":
                channel_value = value
            else:
                # Empty / "Not mentioned." fields still get a channel so the
                # requestable-channel set is uniform across cases.
                channel_value = "No additional information available for this category."
            all_channels[field] = ChannelData(
                name=field,
                channel_type="text",
                description=field_descriptions.get(field, field),
                value=channel_value,
                cost=float(field_meta.get("cost", 0.0)),
                tier=field_meta.get("tier", "unknown"),
                always_given=bool(field_meta.get("always_given", False)),
            )
        initial_channels = {
            name: ch for name, ch in all_channels.items() if ch.always_given
        }
        requestable = {
            name: ch for name, ch in all_channels.items() if not ch.always_given
        }
        if not initial_channels and not requestable:
            logger.debug(f"Skipping NEJM {case_id}: no usable channels found")
            return None
        # ---- Candidates: the 5 MCQ options ----
        options = entry.get("options", [])
        correct = entry.get("correct_answer", entry.get("answer", ""))
        # Handle flat option_A..option_E keys (cx0/nejm-image-challenge format)
        if not options:
            flat_options = {}
            for letter in "ABCDE":
                val = entry.get(f"option_{letter}", "")
                if val:
                    flat_options[letter] = val
            if flat_options:
                options = flat_options
        if isinstance(options, dict):
            # {A: "...", B: "...", ...}
            candidates = [f"{k}. {v}" for k, v in sorted(options.items())]
            gt_label = None
            for k, v in sorted(options.items()):
                if k == correct:
                    gt_label = f"{k}. {v}"
                    break
            if gt_label is None:
                gt_label = candidates[0] if candidates else ""
        elif isinstance(options, list) and options:
            candidates = options
            if isinstance(correct, int):
                gt_label = options[correct] if correct < len(options) else options[0]
            elif isinstance(correct, str) and len(correct) == 1:
                # Letter answer (A=0, B=1, ...)
                idx = ord(correct.upper()) - ord("A")
                gt_label = options[idx] if idx < len(options) else options[0]
            else:
                gt_label = correct
        else:
            candidates = [correct] if correct else ["Unknown"]
            gt_label = correct
        # ---- Votes (physician response distribution) ----
        # Copy so back-filling flat keys never mutates the caller's entry dict.
        votes = dict(entry.get("votes", {}) or {})
        # Handle flat vote keys (option_A_votes, etc.)
        if not votes:
            for letter in "ABCDE":
                val = entry.get(f"option_{letter}_votes", "")
                if val:
                    votes[letter] = val
        return MedicalCase(
            case_id=f"nejm_{case_id}",
            dataset="nejm",
            initial_channels=initial_channels,
            requestable_channels=requestable,
            candidates=candidates,
            ground_truth=gt_label,
            ground_truth_rank=(candidates.index(gt_label) if gt_label in candidates else 0),
            metadata={
                "date": case_id,
                "votes": votes,
                "full_vignette": entry.get("question", entry.get("prompt", entry.get("vignette", ""))),
                "parsed_fields": parsed_vignette,
            },
        )

    def get_human_difficulty(self, case: MedicalCase) -> float | None:
        """
        Compute human difficulty score from physician vote distribution.
        Returns: proportion of physicians who answered correctly (0-1),
        or None if votes unavailable.
        """
        votes = case.metadata.get("votes", {})
        if not votes:
            return None
        # votes might be {A: 0.12, B: 0.65, ...} or {A: 120, B: 650, ...}
        total = sum(float(v) for v in votes.values())
        if total == 0:
            return None
        # Find the correct answer key: ground_truth is "A. <text>", so the
        # letter key either prefixes it or appears inside it.
        gt = case.ground_truth
        for key, val in votes.items():
            if key in gt or gt.startswith(key):
                # total > 1 implies raw counts; otherwise values are already
                # fractions that (approximately) sum to 1.
                return float(val) / total if total > 1 else float(val)
        return None