cxr-vlm-code / data /dataset.py
convitom
f
8356dae
"""
dataset.py
----------
Dataset class for unified CXR instruction tuning.
Handles 4 tasks: findings generation, impression generation, merged report generation (findings + impression), and VQA.
Data sources:
- MIMIC-CXR (physionet.org/content/mimic-cxr/2.1.0)
└── reports/*.txt → findings + impression labels
- MIMIC-CXR-JPG (physionet.org/content/mimic-cxr-jpg/2.0.0)
└── files/**/*.jpg → images
- MIMIC-Ext-CXR-QBA (physionet.org/content/mimic-ext-cxr-qba/1.0.0)
└── *.json → VQA pairs
NOTE: The instruction JSON is pre-built offline by `build_instruct_json` (which
delegates to `data.mimic_cxr_builder`); this module loads, filters, and tokenizes it.
"""
import json
import random
from pathlib import Path
from typing import Optional, Dict, List, Literal
import torch
from torch.utils.data import Dataset
from PIL import Image
from .prompt_templates import build_training_sample
from model.rad_dino import BioViLTEncoder
# Task names accepted by CXRInstructDataset(task=...).
TaskType = Literal[
    "findings",
    "impression",
    "report",
    "vqa",
    "mixed",
]
# ─── Output format helpers (used by both builders and evaluation) ───────────
# Section tags prefixed to the merged "report" target by format_merged_report.
REPORT_FINDINGS_TAG = "Findings:"
REPORT_IMPRESSION_TAG = "Impression:"
def format_merged_report(findings: Optional[str], impression: Optional[str]) -> Optional[str]:
    """
    Build the target string for the merged "report" task.

    Returns:
        "Findings: <f>\n\nImpression: <i>"   when both sections are non-blank
        "Impression: <i>"                    when only the impression is non-blank
        None                                 otherwise (no impression at all)

    Both sections are whitespace-stripped before being embedded.

    Note: samples that carry ONLY findings are intentionally dropped (None).
    The merged task teaches the model to produce a full report that ends
    with an impression; findings-only targets would teach it to omit the
    impression section, which is exactly what we do not want.
    """
    # Evaluate findings first, then impression (same order as a plain
    # `x and x.strip()` check), so invalid inputs fail identically.
    findings_text = findings.strip() if findings else findings
    impression_text = impression.strip() if impression else impression

    if not impression_text:
        # Neither section, or findings-only: drop the sample.
        return None
    if not findings_text:
        return f"{REPORT_IMPRESSION_TAG} {impression.strip()}"
    return f"{REPORT_FINDINGS_TAG} {findings.strip()}\n\n{REPORT_IMPRESSION_TAG} {impression.strip()}"
def parse_generated_report(text: str) -> Dict[str, str]:
    """
    Split a generated/ground-truth report string into its two sections.

    Tolerant to missing sections, extra whitespace, and case variations of
    the "Findings:" / "Impression:" tags. Sections are expected in the
    canonical order emitted by `format_merged_report` (findings first, then
    impression). A "findings:" phrase that appears AFTER the impression tag
    is treated as part of the impression text, not as a Findings section.
    This keeps impression-only reports such as "Impression: Compared to
    prior, findings: unchanged." from leaking into the findings field.

    Args:
        text: report string; None or "" yields two empty sections.

    Returns:
        {"findings": <str>, "impression": <str>} — either may be "".
    """
    import re

    if not text:
        return {"findings": "", "impression": ""}

    flags = re.IGNORECASE | re.DOTALL
    # The first impression tag marks the end of the findings region;
    # everything after it belongs to the impression.
    i_match = re.search(r"impression\s*:\s*(.*)", text, flags=flags)
    findings_region = text[: i_match.start()] if i_match else text
    f_match = re.search(r"findings\s*:\s*(.*)", findings_region, flags=flags)

    return {
        "findings": f_match.group(1).strip() if f_match else "",
        "impression": i_match.group(1).strip() if i_match else "",
    }
class CXRInstructDataset(Dataset):
    """
    Unified instruction-following dataset for chest X-ray interpretation.

    In the default (instruction-tuning) mode each item is a dict with:
        image:          (C, H, W) for single-image samples ("image_path"),
                        (max_images, C, H, W) for multi-image samples
                        ("image_paths"), or the cached patch-feature tensor
                        when a feature-cache file exists
        input_ids:      tokenized prompt + target, un-padded, length <= cutoff_len
        attention_mask: all-ones placeholder of the same length; the collator
                        builds the real mask after batch padding
        labels:         copy of input_ids with prompt positions set to -100
        task:           task name string (for per-task logging)

    In ITC mode (`itc_mode=True`) each item is instead
        {"image": <as above>, "text_embed": precomputed text embedding}.

    Args:
        data_path:         path to pre-built instruction JSON (see build_instruct_json)
        image_root:        root directory of MIMIC-CXR-JPG images
        tokenizer:         HuggingFace tokenizer
        transform:         image transform (BioViLTEncoder.get_transform()); when
                           None, a Resize(448x448) + ToTensor + ImageNet-Normalize
                           fallback is used
        task:              "findings" | "impression" | "report" | "vqa" | "mixed"
        split:             "train" | "validate" | "test"
        cutoff_len:        maximum token length of prompt + target
        task_weights:      sampling weights for mixed task {"findings": 0.4, ...}
        max_images:        views per multi-image sample (>1 only useful in
                           multi_image_merged mode)
        feature_cache_dir: optional directory of `{image_path}.pt` patch-feature
                           files that bypass JPEG loading
        itc_mode:          Stage-1 ITC mode: yield (image, text_embed) pairs
        itc_text_cache:    .pt file with {study_id: tensor} (or {"embeds": {...}});
                           required when itc_mode=True
    """

    def __init__(
        self,
        data_path: str,
        image_root: str,
        tokenizer,
        transform=None,
        task: "TaskType" = "mixed",
        split: str = "train",
        cutoff_len: int = 512,
        task_weights: Optional[Dict[str, float]] = None,
        max_images: int = 1,  # >1 only useful in multi_image_merged mode
        feature_cache_dir: Optional[str] = None,
        itc_mode: bool = False,  # Stage-1 ITC: yield (image, text_embed)
        itc_text_cache: Optional[str] = None,  # {study_id: tensor[proj_dim]} .pt file
    ):
        self.image_root = Path(image_root)
        self.tokenizer = tokenizer
        self.transform = transform
        self.task = task
        self.split = split
        self.cutoff_len = cutoff_len
        self.max_images = max(1, int(max_images))

        # ── ITC Stage-1 mode ────────────────────────────────────────────────
        # Yields {image, text_embed} instead of tokenized prompt/target. The
        # text embeddings are PRECOMPUTED (CXR-BERT) and keyed by study_id; the
        # dataset is deduped to ONE image per study so every text embedding is
        # a unique InfoNCE positive (no in-batch collisions).
        # Must be set up BEFORE _load_data, which reads itc_mode/itc_text_embeds.
        self.itc_mode = bool(itc_mode)
        self.itc_text_embeds: Dict[str, torch.Tensor] = {}
        if self.itc_mode:
            if not itc_text_cache:
                raise ValueError("itc_mode=True requires itc_text_cache (.pt path).")
            # NOTE(review): unlike the feature-cache load in _load_image, no
            # explicit weights_only=True is passed here; torch.load may unpickle
            # arbitrary objects, so only point this at trusted cache files.
            raw = torch.load(itc_text_cache, map_location="cpu")
            # accept either {study_id: tensor} or {"embeds": {...}} wrappers
            self.itc_text_embeds = raw.get("embeds", raw) if isinstance(raw, dict) else raw
            print(f"[CXRInstructDataset] ITC mode: loaded "
                  f"{len(self.itc_text_embeds):,} text embeddings ← {itc_text_cache}")

        # When set, _load_image first checks {feature_cache_dir}/{relpath}.pt
        # and returns the cached (P, 768) patch-feature tensor instead of the
        # raw image. The model detects this by tensor last-dim and skips the
        # frozen encoder entirely. Safe because no random augmentation is
        # applied (Resize + ToTensor + Normalize are deterministic).
        self.feature_cache_dir = Path(feature_cache_dir) if feature_cache_dir else None

        # Built on first use by _get_fallback_transform (only needed when
        # transform is None).
        self._fallback_transform = None

        self.task_weights = task_weights or {
            "findings": 0.4,
            "impression": 0.2,
            "vqa": 0.4,
        }

        # Load pre-built instruction JSON
        # Format: list of dicts with keys:
        #   image_path, task, target, question (for VQA), structured_findings
        self.samples = self._load_data(data_path, split)
        print(f"[{split}] Loaded {len(self.samples)} samples for task={task}")

    def _load_data(self, data_path: str, split: str) -> List[Dict]:
        """
        Load the instruction JSON and filter it by split and task type.

        Samples without a "split" key are treated as "train". In ITC mode the
        task filter is skipped and the split is instead deduplicated to one
        sample per study (see _select_itc_samples).
        """
        # JSON is UTF-8 by spec; do not depend on the platform default encoding.
        with open(data_path, "r", encoding="utf-8") as f:
            all_samples = json.load(f)

        # Filter by split
        samples = [s for s in all_samples if s.get("split", "train") == split]

        if getattr(self, "itc_mode", False):
            return self._select_itc_samples(samples, split)

        # Filter by task
        if self.task != "mixed":
            samples = [s for s in samples if s["task"] == self.task]

        # Sanity check: report_mode is enforced at build time, but if the
        # caller asks for task=report yet the JSON was built in "split" mode
        # (or vice versa) the dataset will be empty / wrong. Warn loudly.
        if not samples:
            tasks_in_file = sorted({s["task"] for s in all_samples
                                    if s.get("split", "train") == split})
            print(f"[CXRInstructDataset] WARNING: 0 samples for task={self.task!r} "
                  f"in split={split!r}. Tasks present in JSON: {tasks_in_file}. "
                  f"Did you forget to rebuild the instruct JSON after changing "
                  f"data.report_mode in train_config.yaml?")
        return samples

    def _select_itc_samples(self, samples: List[Dict], split: str) -> List[Dict]:
        """Keep one image-bearing sample per study that has a text embedding (ITC mode)."""
        # Each study's canonical text (findings-first, computed at precompute
        # time) is a unique InfoNCE positive, hence one sample per study.
        seen_studies = set()
        itc_samples: List[Dict] = []
        n_no_embed = n_no_study = 0
        for s in samples:
            # study_id when present (MIMIC); else fall back to image_path
            # so datasets without study ids (IU-Xray) still smoke-test.
            sid = s.get("study_id") or s.get("image_path")
            if not sid:
                n_no_study += 1
                continue
            if sid in seen_studies:
                continue
            if sid not in self.itc_text_embeds:
                n_no_embed += 1
                continue
            # Both single-image ("image_path") and multi-image ("image_paths")
            # entries are accepted. A sample with neither is skipped WITHOUT
            # marking the study as seen, so a later image-bearing sample of
            # the same study can still be selected.
            if not s.get("image_path") and not s.get("image_paths"):
                continue
            seen_studies.add(sid)
            itc_samples.append(s)
        print(f"[CXRInstructDataset] ITC[{split}]: {len(itc_samples):,} studies "
              f"(dropped {n_no_embed:,} w/o text-embed, {n_no_study:,} w/o study_id)")
        return itc_samples

    def __len__(self) -> int:
        return len(self.samples)

    def get_per_sample_weights(self) -> Optional[List[float]]:
        """
        Build per-sample weights for `torch.utils.data.WeightedRandomSampler`
        so that, in expectation, each task occupies its configured fraction of
        drawn training samples — regardless of how many samples of each task
        exist in the JSON.

        Math:
            For task t with N_t samples in the JSON and configured weight w_t,
            give every sample of t the weight `w_t / N_t`. The aggregate
            probability of drawing ANY sample of task t over one draw becomes
            `N_t * (w_t / N_t) = w_t`, which is exactly the desired ratio.

        Tasks with weight 0 (e.g. VQA on IU-Xray) get weight 0 → never drawn.
        Tasks present in the dataset but absent from `self.task_weights` also
        get weight 0, and a WARNING naming them is printed. For example, "report"
        samples are not covered by the default weights.

        Returns:
            list of floats of length len(self.samples), or None if this is an
            ITC dataset or a single-task dataset (`self.task != "mixed"`) — in
            those cases the default uniform `RandomSampler` is correct.
        """
        if getattr(self, "itc_mode", False):
            return None  # ITC: uniform sampling over deduped studies
        if self.task != "mixed":
            return None

        # Count samples per task that actually appear in this dataset.
        counts: Dict[str, int] = {}
        for s in self.samples:
            counts[s["task"]] = counts.get(s["task"], 0) + 1

        # A missing task_weights entry silently zeroes a whole task. Surface
        # it so the misconfiguration is visible in the training logs.
        missing = sorted(t for t in counts if t not in self.task_weights)
        if missing:
            print(f"[CXRInstructDataset] WARNING: tasks {missing} are present in the "
                  f"dataset but absent from task_weights {sorted(self.task_weights)}; "
                  f"their samples get weight 0 and will never be drawn.")

        # Per-sample weight = w_task / N_task. Tasks not in task_weights → 0.
        weights = [
            float(self.task_weights.get(s["task"], 0.0)) / counts[s["task"]]
            for s in self.samples
        ]

        # Sanity: print effective per-task probabilities once so the actual
        # mix during training is visible in logs (helps catch misconfigured
        # weights vs. JSON-task-set mismatch).
        eff = {t: float(self.task_weights.get(t, 0.0)) for t in counts}
        eff_sum = sum(eff.values()) or 1.0
        eff = {t: round(v / eff_sum, 4) for t, v in eff.items()}
        print(f"[CXRInstructDataset] WeightedRandomSampler effective task mix: "
              f"{eff} (counts: {counts})")
        return weights

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.samples[idx]

        # ── Load image(s) ───────────────────────────────────────────────────
        # In single-image modes ("image_path" set), `image` is (C, H, W).
        # In multi-image mode ("image_paths" set), `image` is (max_images, C, H, W)
        # padded/truncated to self.max_images by duplicating the first view.
        if sample.get("image_paths"):
            image = self._load_image_stack(sample["image_paths"])  # (N, C, H, W)
        else:
            image = self._load_image(sample["image_path"])  # (C, H, W)

        # ── ITC mode: return (image, precomputed text embedding) ─────────────
        if getattr(self, "itc_mode", False):
            # Same key rule as _select_itc_samples (study_id, else image_path).
            key = sample.get("study_id") or sample["image_path"]
            text_embed = self.itc_text_embeds[key].float()  # (proj_dim,)
            return {"image": image, "text_embed": text_embed}

        # ── Build prompt + target ────────────────────────────────────────────
        training_sample = build_training_sample(
            task=sample["task"],
            target=sample["target"],
            question=sample.get("question"),  # VQA only
            structured_findings=sample.get("structured_findings"),  # optional
            randomize=(self.split == "train"),  # only randomize in train
        )

        # ── Tokenize ─────────────────────────────────────────────────────────
        input_ids, labels = self._tokenize(
            prompt=training_sample["prompt"],
            target=training_sample["target"],
        )

        # `input_ids` is un-padded at this point. The collator builds the
        # final attention_mask (1 for real tokens, 0 for batch-level
        # padding) after stacking. We pass an all-ones placeholder of the
        # correct per-sample length so other parts of the pipeline that
        # peek at the dataset output still see the expected key.
        return {
            "image": image,
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "labels": labels,
            "task": sample["task"],  # for per-task logging
        }

    def _load_image_stack(self, image_paths: List[str]) -> torch.Tensor:
        """
        Load a list of images and stack them into (max_images, C, H, W).

        Padding rule: if the study has fewer images than max_images, duplicate
        the FIRST view (frontal) until we hit max_images. Truncation rule:
        keep the first max_images views in the order the builder emitted.

        This keeps every multi-image sample the SAME tensor shape so the
        default torch.utils.data collator can batch them without a custom
        collator. The downside is wasted compute on duplicated views; if you
        need a per-sample N, replace with a custom collator that pads inside
        the batch.

        NOTE(review): if the feature cache covers only some views of a study,
        cached (P, 768) and raw (C, H, W) tensors get mixed and torch.stack
        will fail — keep the cache complete per study.

        Raises:
            ValueError: if image_paths is empty.
        """
        if not image_paths:
            raise ValueError("empty image_paths for sample")
        imgs = [self._load_image(p) for p in image_paths[: self.max_images]]
        while len(imgs) < self.max_images:
            imgs.append(imgs[0].clone())  # pad with frontal
        return torch.stack(imgs, dim=0)  # (N, C, H, W)

    def _get_fallback_transform(self):
        """Return the basic transform used when no transform was provided, building it once."""
        fallback = getattr(self, "_fallback_transform", None)
        if fallback is None:
            from torchvision import transforms
            fallback = transforms.Compose([
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
            self._fallback_transform = fallback
        return fallback

    def _load_image(self, image_path: str) -> torch.Tensor:
        """
        Load and transform a chest X-ray image, or load pre-computed patch
        features when a feature cache hits.

        Args:
            image_path: relative path from image_root
                        e.g. "files/p10/p10000032/s50414267/02aa804e.jpg"

        Returns either (C, H, W) for raw images or (P, 768) when a cached
        feature file exists at `{feature_cache_dir}/{image_path}.pt`.
        """
        # ── Fast path: pre-computed patch features ────────────────────────
        if self.feature_cache_dir is not None:
            cache_path = self.feature_cache_dir / (image_path + ".pt")
            if cache_path.is_file():
                # Load tensor — should be (P, 768) for a single image.
                return torch.load(cache_path, map_location="cpu", weights_only=True)

        # ── Slow path: read JPEG + transform ──────────────────────────────
        full_path = self.image_root / image_path
        # PIL opens files lazily; the context manager guarantees the handle
        # is released even if decoding fails. convert() returns a new,
        # fully-loaded image, so it stays valid after the file is closed.
        with Image.open(full_path) as raw:
            image = raw.convert("RGB")

        transform = self.transform if self.transform is not None else self._get_fallback_transform()
        return transform(image)

    def _tokenize(
        self,
        prompt: str,
        target: str,
    ) -> tuple:
        """
        Tokenize prompt+target. Labels have -100 for prompt tokens
        (so loss is only computed on target tokens).

        Returns UN-PADDED tensors of variable length ≤ cutoff_len. The
        CXRDataCollator pads them to the max length within each batch,
        so short batches (e.g. all-VQA) skip the padded compute entirely.
        Per-sample padding is no longer applied here; that masking lives
        in the collator now.

        Returns:
            input_ids: (L,)  L ≤ cutoff_len
            labels:    (L,)  with -100 for prompt positions

        NOTE(review): assumes the prompt's tokens form a prefix of the
        full text's tokens and that add_special_tokens only PREPENDS tokens
        (e.g. BOS). If the tokenizer appends EOS, prompt_len also masks the
        first target token. No EOS is appended to the target here — confirm
        build_training_sample / the tokenizer supplies one. If the prompt
        alone reaches cutoff_len, every label is -100 and the sample
        contributes no loss.
        """
        full_text = prompt + " " + target
        prompt_encoded = self.tokenizer.encode(prompt, add_special_tokens=True)
        full_encoded = self.tokenizer.encode(
            full_text,
            add_special_tokens=True,
            max_length=self.cutoff_len,
            truncation=True,
            # No padding — collator pads to max-in-batch (dynamic padding).
        )
        input_ids = torch.tensor(full_encoded, dtype=torch.long)

        # Labels: -100 for prompt tokens, actual token ids for target tokens.
        # (Padding masking happens in the collator since there's no
        # padding at this stage.)
        labels = input_ids.clone()
        prompt_len = min(len(prompt_encoded), self.cutoff_len)
        labels[:prompt_len] = -100
        return input_ids, labels
# ─── Helper: Build Instruction JSON ─────────────────────────────────────────
def build_instruct_json(
    mimic_cxr_root: str,
    output_path: str,
    chexpert_csv: Optional[str] = None,
    vqa_data_root: Optional[str] = None,
    report_mode: str = "split",
    image_mode: str = "all_views_split",
) -> str:
    """
    Build the unified MIMIC-CXR instruction JSON and return the builder's result.

    This is a thin delegate to
    `data.mimic_cxr_builder.build_mimic_cxr_instruct_json`. That builder walks
    the pre-split MIMIC layout (train/valid/test) and parses findings/impression
    from the report .txt files. It also bakes the 14 CheXpert labels (oracle,
    from `*chexpert*.csv`) into `structured_findings` as the PNU 3-section
    string (U-MultiClass, META-CXR format), which is the RaDialog
    image + abnormality-guidance setup. `report_mode` / `image_mode` mirror
    the IU builder.

    Parameter mapping to the builder:
        mimic_cxr_root -> mimic_root
        vqa_data_root  -> vqa_root
        all other arguments are forwarded under the same name.

    Output entries match the shared schema, e.g.:
        {"image_path": "train/p10/p10000032/s50414267/02aa804e.jpg",
         "task": "findings", "target": "The lungs are clear...",
         "question": null,
         "structured_findings": "Positive Abnormalities: None\\n
                                 Negative Abnormalities: No Finding, ...\\n
                                 Uncertain Abnormalities: None",
         "split": "train", "study_id": "s50414267",
         "subject_id": "p10000032"}
    """
    # Imported lazily so loading this module does not require the builder.
    from .mimic_cxr_builder import build_mimic_cxr_instruct_json as _builder

    forwarded = dict(
        mimic_root=mimic_cxr_root,
        output_path=output_path,
        chexpert_csv=chexpert_csv,
        vqa_root=vqa_data_root,
        report_mode=report_mode,
        image_mode=image_mode,
    )
    return _builder(**forwarded)