# Uploaded via huggingface_hub by yuxbox (commit a1aaf30, verified)
"""
NEJM Image Challenge Dataset Loader.
Expects the cx0/nejm-image-challenge dataset structure:
nejm/
β”œβ”€β”€ data.json (or nejm_data.json)
β”‚ Each entry: {date, image_url, prompt (clinical vignette),
β”‚ options [A..E], correct_answer, votes}
β”œβ”€β”€ images/ (downloaded images, named by date YYYYMMDD.jpg)
└── parsed_vignettes.json (pre-parsed structured fields, optional)
The clinical vignette is decomposed into 5 requestable text channels
using LLM-based parsing (see scripts/parse_nejm_vignettes.py).
"""
import json
import logging
import random
import re
from pathlib import Path
from .base import DatasetBase, MedicalCase, ChannelData
from api_client import encode_image_to_base64
import config
logger = logging.getLogger(__name__)
# ---- Vignette parsing schema ----
# The five structured text channels a clinical vignette is decomposed into.
# Keys must match the field names produced by the LLM / rule-based parsers
# and the channel definitions in config ("nejm" dataset).
VIGNETTE_FIELDS = [
    "demographics",
    "chief_complaint",
    "medical_history",
    "exam_findings",
    "investigations",
]
# Prompt template for LLM-based vignette parsing. {vignette} is filled in via
# str.format(); the doubled braces keep the JSON example literal through
# formatting. The model is instructed to return exactly the five
# VIGNETTE_FIELDS keys as a bare JSON object (no markdown fences).
VIGNETTE_PARSE_PROMPT = """You are a medical data extraction system. Parse the following clinical \
vignette into exactly 5 structured fields. Extract ONLY information that is explicitly stated. \
If a field has no relevant information, write "Not mentioned."
FIELDS:
1. demographics: Patient age, sex, race/ethnicity if stated.
2. chief_complaint: The primary presenting symptom(s) and their duration.
3. medical_history: Past medical conditions, medications, surgical history, family history, social history (smoking, alcohol, etc.).
4. exam_findings: Physical examination findings, vital signs.
5. investigations: Laboratory results, imaging findings, test results (anything with numbers or test names).
CLINICAL VIGNETTE:
{vignette}
Respond in EXACTLY this JSON format (no markdown, no extra text):
{{"demographics": "...", "chief_complaint": "...", "medical_history": "...", "exam_findings": "...", "investigations": "..."}}"""
class NEJMDataset(DatasetBase):
    """Loader for the NEJM Image Challenge dataset.

    Each case pairs a diagnostic image with a clinical vignette that is
    decomposed into the five requestable text channels in VIGNETTE_FIELDS,
    plus the physician vote distribution used as a human-difficulty signal.
    """

    # JSON files in the data dir that are caches/outputs, never the raw
    # dataset. Excluded from the glob fallback in _load_raw_data so the
    # parsed-vignette cache is not mistaken for the dataset itself.
    _NON_DATA_JSON = {"parsed_vignettes.json"}

    # Human-readable channel descriptions surfaced to the agent.
    _FIELD_DESCRIPTIONS = {
        "demographics": "Patient age, sex, and ethnicity if mentioned",
        "chief_complaint": "The presenting symptom(s) and their duration",
        "medical_history": "Past medical conditions, medications, family and social history",
        "exam_findings": "Physical examination results and observations",
        "investigations": "Laboratory values, prior imaging results, and test outcomes",
    }

    # Heuristic sentence classifiers for the rule-based parser, compiled
    # once instead of per call.
    _DEMO_RE = re.compile(
        r'\b(\d{1,3})[-\s]year[-\s]old\b|'
        r'\b(male|female|man|woman|boy|girl)\b',
        re.IGNORECASE,
    )
    _COMPLAINT_RE = re.compile(
        r'\bpresent(?:s|ed|ing)\b|\bcomplain(?:s|ed|ing)\b|\breport(?:s|ed|ing)\b|'
        r'\bseek(?:s|ing)\b|\badmitted\b',
        re.IGNORECASE,
    )
    _HISTORY_RE = re.compile(
        r'\bhistory\b|\bprevious(?:ly)?\b|\bmedication\b|\btaking\b|\bdiagnosed\b|'
        r'\bsmok(?:es|ing|er)\b|\balcohol\b|\bfamily\b|\bsurgery\b',
        re.IGNORECASE,
    )
    _EXAM_RE = re.compile(
        r'\bexamination\b|\bexam\b|\bpalpat(?:ion|ed)\b|\bauscult(?:ation|ed)\b|'
        r'\bvital\b|\bblood\s+pressure\b|\bheart\s+rate\b|\btemperature\b|'
        r'\bappears\b|\btender\b|\bswollen\b|\berythema\b',
        re.IGNORECASE,
    )
    _INVEST_RE = re.compile(
        r'\b(?:hemoglobin|WBC|platelet|creatinine|BUN|glucose|sodium|potassium)\b|'
        r'\b(?:CT|MRI|X[-\s]?ray|ultrasound|ECG|EKG|biopsy)\b|'
        r'\b\d+\.?\d*\s*(?:mg|g|mL|mmol|mEq|U|IU|mmHg|\/dL|\/L)\b|'
        r'\blaboratory\b|\blab(?:s)?\b|\btest\b|\blevel\b|\bfinding\b',
        re.IGNORECASE,
    )

    def __init__(
        self,
        data_dir: str | Path | None = None,
        split: str = "test",
        vlm_client=None,
        use_cached_parse: bool = True,
    ):
        """
        Args:
            data_dir: Dataset root; defaults to config.DATASET_PATHS["nejm"].
            split: Split name, forwarded to DatasetBase.
            vlm_client: Optional LLM client used for vignette parsing; when
                None a rule-based fallback parser is used instead.
            use_cached_parse: Reuse parsed_vignettes.json when present.
        """
        super().__init__(data_dir or config.DATASET_PATHS["nejm"], split)
        self.vlm_client = vlm_client
        self.use_cached_parse = use_cached_parse
        self._parsed_cache_path = self.data_dir / "parsed_vignettes.json"

    def get_name(self) -> str:
        return "nejm"

    # ---- Shared entry accessors -------------------------------------

    @staticmethod
    def _get_case_id(entry: dict) -> str:
        """Return the stable per-case identifier (date string, then id)."""
        return entry.get("date", entry.get("id", "unknown"))

    @staticmethod
    def _get_vignette(entry: dict) -> str:
        """Return the raw vignette text under any of its known keys."""
        return entry.get("question", entry.get("prompt", entry.get("vignette", "")))

    # ---- Loading ----------------------------------------------------

    def load(self) -> list[MedicalCase]:
        """Load all NEJM cases into MedicalCase objects.

        Returns:
            The list of built cases (also stored on self.cases); empty
            when no data file is found.
        """
        logger.info(f"Loading NEJM dataset from {self.data_dir}")
        raw_data = self._load_raw_data()
        if not raw_data:
            return []
        logger.info(f"Found {len(raw_data)} NEJM cases")
        parsed = self._load_or_parse_vignettes(raw_data)
        self.cases = []
        for entry in raw_data:
            case_id = self._get_case_id(entry)
            case = self._build_case(entry, parsed.get(case_id, {}))
            if case is not None:
                self.cases.append(case)
        logger.info(f"Loaded {len(self.cases)} NEJM cases")
        return self.cases

    def _load_raw_data(self) -> list[dict]:
        """Load the raw NEJM dataset JSON.

        Tries known filenames first, then falls back to any *.json in the
        data dir — excluding known cache files, so the parsed-vignette
        cache is never loaded as the dataset (fixes a latent bug where
        parsed_vignettes.json could win the glob).
        """
        for name in ["data.json", "nejm_data.json", "nejm.json", "dataset.json"]:
            p = self.data_dir / name
            if p.exists():
                with open(p, encoding="utf-8") as f:
                    data = json.load(f)
                if isinstance(data, dict):
                    # Handle {date: entry} format.
                    return [{"date": k, **v} if isinstance(v, dict) else v
                            for k, v in data.items()]
                return data
        # Fallback: first non-cache JSON file in the directory.
        jsons = [p for p in sorted(self.data_dir.glob("*.json"))
                 if p.name not in self._NON_DATA_JSON]
        if jsons:
            with open(jsons[0], encoding="utf-8") as f:
                return json.load(f)
        logger.error(f"No data file found in {self.data_dir}")
        return []

    # ---- Vignette parsing -------------------------------------------

    def _load_or_parse_vignettes(self, raw_data: list[dict]) -> dict:
        """Load cached parsed vignettes, or parse them (LLM or rules).

        Only LLM parses are written to the cache file: they are expensive,
        while rule-based parses are cheap enough to redo on every load.
        """
        if self.use_cached_parse and self._parsed_cache_path.exists():
            logger.info(f"Loading cached vignette parses from {self._parsed_cache_path}")
            with open(self._parsed_cache_path, encoding="utf-8") as f:
                return json.load(f)
        use_llm = self.vlm_client is not None
        if use_llm:
            logger.info("Parsing vignettes with LLM (this may take a while)...")
        else:
            logger.info("No LLM client available. Using rule-based vignette parsing (less accurate).")
        parsed = {}
        for entry in raw_data:
            vignette = self._get_vignette(entry)
            if not vignette:
                continue
            case_id = self._get_case_id(entry)
            parsed[case_id] = (
                self._parse_vignette_with_llm(vignette)
                if use_llm
                else self._parse_vignette_rules(vignette)
            )
        if use_llm:
            with open(self._parsed_cache_path, "w", encoding="utf-8") as f:
                json.dump(parsed, f, indent=2)
            logger.info(f"Cached {len(parsed)} parsed vignettes")
        return parsed

    def _parse_vignette_with_llm(self, vignette: str) -> dict:
        """Parse a single vignette using the LLM API.

        Falls back to the rule-based parser on any API/JSON failure, so a
        flaky model call never aborts dataset loading.
        """
        prompt = VIGNETTE_PARSE_PROMPT.format(vignette=vignette)
        try:
            response = self.vlm_client.call_with_retry(
                system_prompt="You are a medical data extraction system. Respond only with valid JSON.",
                user_text=prompt,
                images=None,
                temperature=0.0,
                max_tokens=1024,
            )
            text = response.text.strip()
            # Strip markdown code fences if present.
            text = re.sub(r"^```(?:json)?\s*", "", text)
            text = re.sub(r"\s*```$", "", text)
            parsed = json.loads(text)
            # Guarantee every expected field exists.
            for field in VIGNETTE_FIELDS:
                parsed.setdefault(field, "Not mentioned.")
            return parsed
        except Exception as e:
            logger.warning(f"LLM vignette parsing failed: {e}. Falling back to rules.")
            return self._parse_vignette_rules(vignette)

    def _parse_vignette_rules(self, vignette: str) -> dict:
        """Rule-based fallback for vignette parsing.

        Splits the vignette into sentences and assigns each to the first
        matching field (investigations > exam > history > complaint);
        unmatched sentences default to chief_complaint.
        """
        result = {f: "" for f in VIGNETTE_FIELDS}
        sentences = re.split(r'(?<=[.!?])\s+', vignette)
        classifiers = [
            ("investigations", self._INVEST_RE),
            ("exam_findings", self._EXAM_RE),
            ("medical_history", self._HISTORY_RE),
            ("chief_complaint", self._COMPLAINT_RE),
        ]
        for raw in sentences:
            sent = raw.strip()
            if not sent:
                continue
            # Demographics: only the first sentence that looks demographic.
            if not result["demographics"] and self._DEMO_RE.search(sent):
                result["demographics"] = sent
                continue
            for field, pattern in classifiers:
                if pattern.search(sent):
                    result[field] = (result[field] + " " + sent) if result[field] else sent
                    break
            else:
                # Default bucket for sentences matching no classifier.
                result["chief_complaint"] = (
                    (result["chief_complaint"] + " " + sent)
                    if result["chief_complaint"] else sent
                )
        # Replace empty fields with the sentinel.
        for field in VIGNETTE_FIELDS:
            if not result[field].strip():
                result[field] = "Not mentioned."
        return result

    # ---- Case construction ------------------------------------------

    @staticmethod
    def _date_to_yyyymmdd(date_str: str) -> str | None:
        """Convert 'apr-01-2010' style date to '20100401' for image lookup.

        Returns None when no known format matches.
        """
        from datetime import datetime  # local: only needed here
        for fmt in ("%b-%d-%Y", "%B-%d-%Y", "%Y-%m-%d", "%Y%m%d"):
            try:
                return datetime.strptime(date_str, fmt).strftime("%Y%m%d")
            except ValueError:
                continue
        return None

    def _load_case_image(self, case_id: str) -> str | None:
        """Locate and base64-encode the case image, or None when absent.

        Tries exact filename matches (original id and its YYYYMMDD form)
        across known extensions, then globs for any file containing the
        id. An unreadable candidate no longer aborts the extension scan
        (previously a failed encode skipped remaining extensions).
        """
        img_dir = self.data_dir / "images"
        if not img_dir.exists():
            return None
        name_candidates = [case_id]
        yyyymmdd = self._date_to_yyyymmdd(case_id)
        if yyyymmdd:
            name_candidates.append(yyyymmdd)
        for name in name_candidates:
            for ext in (".jpg", ".jpeg", ".png"):
                p = img_dir / f"{name}{ext}"
                if p.exists():
                    try:
                        return encode_image_to_base64(p)
                    except Exception:
                        continue  # unreadable file: try next extension
        # Glob for any match as a last resort.
        for name in name_candidates:
            matches = list(img_dir.glob(f"*{name}*"))
            if matches:
                try:
                    return encode_image_to_base64(matches[0])
                except Exception:
                    pass
                break
        return None

    def _build_channels(self, img_b64: str | None,
                        parsed_vignette: dict) -> dict[str, ChannelData]:
        """Build the image channel (when available) plus all 5 text channels."""
        channels: dict[str, ChannelData] = {}
        if img_b64 is not None:
            image_meta = config.get_channel_definition("nejm", "image")
            channels["image"] = ChannelData(
                name="image",
                channel_type="image",
                description="The primary diagnostic image",
                value=img_b64,
                cost=float(image_meta.get("cost", 0.0)),
                tier=image_meta.get("tier", "unknown"),
                always_given=bool(image_meta.get("always_given", False)),
            )
        for field in VIGNETTE_FIELDS:
            value = parsed_vignette.get(field, "Not mentioned.")
            field_meta = config.get_channel_definition("nejm", field)
            # Empty / "Not mentioned." fields get an explicit placeholder so
            # the channel is still requestable.
            if not value or value.strip() == "Not mentioned.":
                value = "No additional information available for this category."
            channels[field] = ChannelData(
                name=field,
                channel_type="text",
                description=self._FIELD_DESCRIPTIONS.get(field, field),
                value=value,
                cost=float(field_meta.get("cost", 0.0)),
                tier=field_meta.get("tier", "unknown"),
                always_given=bool(field_meta.get("always_given", False)),
            )
        return channels

    @staticmethod
    def _extract_candidates(entry: dict) -> tuple[list[str], str]:
        """Return (candidates, ground-truth label) from the MCQ options.

        Supports dict options ({A: text}), list options with int/letter/text
        answers, flat option_A..option_E keys, and answer-only entries.
        """
        options = entry.get("options", [])
        correct = entry.get("correct_answer", entry.get("answer", ""))
        # Handle flat option_A..option_E keys (cx0/nejm-image-challenge format).
        if not options:
            flat = {letter: entry.get(f"option_{letter}", "") for letter in "ABCDE"}
            flat = {k: v for k, v in flat.items() if v}
            if flat:
                options = flat
        if isinstance(options, dict):
            # {A: "...", B: "...", ...} -> "A. ..." labels.
            items = sorted(options.items())
            candidates = [f"{k}. {v}" for k, v in items]
            gt_label = next(
                (f"{k}. {v}" for k, v in items if k == correct),
                candidates[0] if candidates else "",
            )
        elif isinstance(options, list) and options:
            candidates = options
            if isinstance(correct, int) and not isinstance(correct, bool):
                gt_label = options[correct] if 0 <= correct < len(options) else options[0]
            elif isinstance(correct, str) and len(correct) == 1:
                # Letter answer (A=0, B=1, ...); guard both bounds.
                idx = ord(correct.upper()) - ord("A")
                gt_label = options[idx] if 0 <= idx < len(options) else options[0]
            else:
                gt_label = correct
        else:
            candidates = [correct] if correct else ["Unknown"]
            gt_label = correct
        return candidates, gt_label

    @staticmethod
    def _extract_votes(entry: dict) -> dict:
        """Return the physician vote distribution.

        Copies the dict so the fallback key insertion below never mutates
        the source entry in place.
        """
        votes = entry.get("votes", {})
        if isinstance(votes, dict):
            votes = dict(votes)
        # Handle flat vote keys (option_A_votes, etc.).
        if not votes:
            for letter in "ABCDE":
                val = entry.get(f"option_{letter}_votes", "")
                if val:
                    votes[letter] = val
        return votes

    def _build_case(self, entry: dict, parsed_vignette: dict) -> MedicalCase | None:
        """Convert a raw NEJM entry + parsed vignette into a MedicalCase.

        Returns None when the case yields no usable channels at all.
        """
        case_id = self._get_case_id(entry)
        img_b64 = self._load_case_image(case_id)
        all_channels = self._build_channels(img_b64, parsed_vignette)
        # Split channels by whether they are granted up front or on request.
        initial_channels = {
            name: ch for name, ch in all_channels.items() if ch.always_given
        }
        requestable = {
            name: ch for name, ch in all_channels.items() if not ch.always_given
        }
        if not initial_channels and not requestable:
            logger.debug(f"Skipping NEJM {case_id}: no usable channels found")
            return None
        candidates, gt_label = self._extract_candidates(entry)
        votes = self._extract_votes(entry)
        return MedicalCase(
            case_id=f"nejm_{case_id}",
            dataset="nejm",
            initial_channels=initial_channels,
            requestable_channels=requestable,
            candidates=candidates,
            ground_truth=gt_label,
            ground_truth_rank=(candidates.index(gt_label) if gt_label in candidates else 0),
            metadata={
                "date": case_id,
                "votes": votes,
                "full_vignette": self._get_vignette(entry),
                "parsed_fields": parsed_vignette,
            },
        )

    # ---- Difficulty -------------------------------------------------

    def get_human_difficulty(self, case: MedicalCase) -> float | None:
        """Compute human difficulty from the physician vote distribution.

        Returns:
            Proportion of physicians who answered correctly (0-1), or None
            when votes are unavailable, sum to zero, or the correct option
            cannot be matched. Handles both fractional distributions
            ({A: 0.12, ...}) and raw counts ({A: 120, ...}).
        """
        votes = case.metadata.get("votes", {})
        if not votes:
            return None
        total = sum(float(v) for v in votes.values())
        if total == 0:
            return None
        gt = case.ground_truth
        for key, val in votes.items():
            # Prefix-match "A" against "A. Diagnosis" style labels. The old
            # substring test (`key in gt`) wrongly matched any capital letter
            # appearing anywhere in the answer text (e.g. "A" in "B. Asthma").
            if gt == key or gt.startswith(f"{key}.") or gt.startswith(f"{key} "):
                # Raw counts need normalizing; fractions already sum to ~1.
                return float(val) / total if total > 1 else float(val)
        return None