"""
dataset.py
----------
Dataset class for unified CXR instruction tuning.
Handles 4 tasks: findings generation, impression generation, merged report
generation (findings + impression in one target), and VQA.
Data sources:
- MIMIC-CXR (physionet.org/content/mimic-cxr/2.1.0)
└── reports/*.txt → findings + impression labels
- MIMIC-CXR-JPG (physionet.org/content/mimic-cxr-jpg/2.0.0)
└── files/**/*.jpg → images
- MIMIC-Ext-CXR-QBA (physionet.org/content/mimic-ext-cxr-qba/1.0.0)
└── *.json → VQA pairs
NOTE: Data loading implementation left as TODO —
fill in paths and JSON structure once data is downloaded.
"""
import json
import random
from pathlib import Path
from typing import Optional, Dict, List, Literal
import torch
from torch.utils.data import Dataset
from PIL import Image
from .prompt_templates import build_training_sample
from model.rad_dino import BioViLTEncoder
# Task selector accepted by CXRInstructDataset: either one concrete task name
# (only samples whose "task" field equals it are kept) or "mixed", which keeps
# samples of every task found in the instruction JSON.
TaskType = Literal["findings", "impression", "report", "vqa", "mixed"]
# ─── Output format helpers (used by both builders and evaluation) ───────────
# Section tags that prefix each part of a merged report target string.
REPORT_FINDINGS_TAG = "Findings:"
REPORT_IMPRESSION_TAG = "Impression:"


def format_merged_report(findings: Optional[str], impression: Optional[str]) -> Optional[str]:
    """
    Assemble the target text for the merged "report" task.

    Both inputs are whitespace-stripped; a section that is None, empty, or
    whitespace-only counts as missing.

    Returns:
        "Findings: <f>\\n\\nImpression: <i>"  when both sections are present
        "Impression: <i>"                    when only the impression is present
        None                                 when the impression is missing

    Findings-only samples are rejected on purpose: the merged task should
    teach the model to finish every report with an impression, and targets
    lacking one would train it to leave that section out.
    """
    findings_text = findings.strip() if findings else ""
    impression_text = impression.strip() if impression else ""

    # No impression → nothing usable, whether or not findings exist.
    if not impression_text:
        return None

    impression_part = f"{REPORT_IMPRESSION_TAG} {impression_text}"
    if not findings_text:
        return impression_part
    return f"{REPORT_FINDINGS_TAG} {findings_text}\n\n{impression_part}"
def parse_generated_report(text: Optional[str]) -> Dict[str, str]:
    """
    Split a generated/ground-truth report string into its two sections.

    Tolerant to missing sections, extra whitespace, case variations of the
    tags, and empty or None input (which yields two empty sections).

    Args:
        text: report string that may contain "Findings:" and/or
              "Impression:" tags (matched case-insensitively, with optional
              whitespace around the colon).

    Returns:
        {"findings": <str>, "impression": <str>} — either may be "".
        The findings section runs from its tag up to the next impression tag
        (or the end of the text); the impression section runs from the first
        impression tag to the end of the text.
    """
    import re

    # A failed generation may hand us None or ""; report both sections as
    # missing instead of letting re.search raise TypeError.
    if not text:
        return {"findings": "", "impression": ""}

    # DOTALL lets a section span multiple lines; the lazy group plus the
    # lookahead stops the findings section at the impression tag.
    f_match = re.search(r"findings\s*:\s*(.*?)(?=impression\s*:|$)", text,
                        flags=re.IGNORECASE | re.DOTALL)
    i_match = re.search(r"impression\s*:\s*(.*)", text,
                        flags=re.IGNORECASE | re.DOTALL)
    return {
        "findings": f_match.group(1).strip() if f_match else "",
        "impression": i_match.group(1).strip() if i_match else "",
    }
class CXRInstructDataset(Dataset):
    """
    Unified instruction-following dataset for chest X-ray interpretation.

    Default mode — each item is a dict with:
        image:          output of `_load_image` (single view) or
                        `_load_image_stack` (multi-view)
        input_ids:      un-padded token ids of "<prompt> <target>"
        attention_mask: all-ones placeholder, same length as input_ids
        labels:         copy of input_ids with prompt positions set to -100
        task:           the sample's task name (a plain str, not a tensor)

    ITC mode (`itc_mode=True`) — each item is `{"image", "text_embed"}` and
    the sample list is reduced to one entry per study.

    Args:
        data_path:         path to pre-built instruction JSON (see build_instruct_json.py)
        image_root:        root directory of MIMIC-CXR-JPG images
        tokenizer:         HuggingFace tokenizer
        transform:         image transform (BioViLTEncoder.get_transform())
        task:              "findings" | "impression" | "report" | "vqa" | "mixed"
        split:             split name matched against each sample's "split"
                           field (samples without one count as "train")
        cutoff_len:        maximum token length
        task_weights:      sampling weights for mixed task {"findings": 0.4, ...}
        max_images:        views per sample in multi-image mode (values
                           below 1 are clamped to 1)
        feature_cache_dir: optional directory holding pre-computed `.pt`
                           features, looked up before reading the image
        itc_mode:          yield (image, text_embed) instead of tokens
        itc_text_cache:    `.pt` file with the text embeddings; required
                           when itc_mode=True

    Raises:
        ValueError: if `itc_mode` is set without `itc_text_cache`.
    """

    def __init__(
        self,
        data_path: str,
        image_root: str,
        tokenizer,
        transform=None,
        task: "TaskType" = "mixed",
        split: str = "train",
        cutoff_len: int = 512,
        task_weights: Optional[Dict[str, float]] = None,
        max_images: int = 1,                    # >1 only useful in multi_image_merged mode
        feature_cache_dir: Optional[str] = None,
        itc_mode: bool = False,                 # Stage-1 ITC: yield (image, text_embed)
        itc_text_cache: Optional[str] = None,   # {study_id: tensor[proj_dim]} .pt file
    ):
        super().__init__()
        self.image_root = Path(image_root)
        self.tokenizer = tokenizer
        self.transform = transform
        self.task = task
        self.split = split
        self.cutoff_len = cutoff_len
        self.max_images = max(1, int(max_images))

        # ── ITC Stage-1 mode ────────────────────────────────────────────────
        # Yields {image, text_embed} instead of tokenized prompt/target. The
        # text embeddings are PRECOMPUTED (CXR-BERT) and keyed by study_id; the
        # dataset is deduped to ONE image per study so every text embedding is
        # a unique InfoNCE positive (no in-batch collisions).
        self.itc_mode = bool(itc_mode)
        self.itc_text_embeds: Dict[str, torch.Tensor] = {}
        if self.itc_mode:
            if not itc_text_cache:
                raise ValueError("itc_mode=True requires itc_text_cache (.pt path).")
            # weights_only=True keeps torch.load from executing arbitrary
            # pickled code, matching how _load_image reads cached features.
            raw = torch.load(itc_text_cache, map_location="cpu", weights_only=True)
            # accept either {study_id: tensor} or {"embeds": {...}} wrappers
            self.itc_text_embeds = raw.get("embeds", raw) if isinstance(raw, dict) else raw
            print(f"[CXRInstructDataset] ITC mode: loaded "
                  f"{len(self.itc_text_embeds):,} text embeddings ← {itc_text_cache}")

        # When set, _load_image first checks {feature_cache_dir}/{relpath}.pt
        # and returns the cached (P, 768) patch-feature tensor instead of the
        # raw image. The model detects this by tensor last-dim and skips the
        # frozen encoder entirely. Safe because no random augmentation is
        # applied (Resize + ToTensor + Normalize are deterministic).
        self.feature_cache_dir = Path(feature_cache_dir) if feature_cache_dir else None

        self.task_weights = task_weights or {
            "findings": 0.4,
            "impression": 0.2,
            "vqa": 0.4,
        }

        # Load pre-built instruction JSON
        # Format: list of dicts with keys:
        #   image_path, task, target, question (for VQA), structured_findings
        self.samples = self._load_data(data_path, split)
        print(f"[{split}] Loaded {len(self.samples)} samples for task={task}")

    def _load_data(self, data_path: str, split: str) -> List[Dict]:
        """
        Load the instruction JSON and keep the samples for this dataset.

        Samples are filtered by `split` first. In ITC mode the result is then
        deduplicated per study (see `_select_itc_samples`) and the task filter
        is skipped; otherwise samples are filtered by `self.task` unless it is
        "mixed". A warning is printed when the task filter leaves nothing.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(data_path, "r", encoding="utf-8") as f:
            all_samples = json.load(f)

        # Filter by split
        samples = [s for s in all_samples if s.get("split", "train") == split]

        # getattr keeps this usable on instances built without __init__.
        if getattr(self, "itc_mode", False):
            return self._select_itc_samples(samples, split)

        # Filter by task
        if self.task != "mixed":
            samples = [s for s in samples if s["task"] == self.task]
            # Sanity check: report_mode is enforced at build time, but if the
            # caller asks for task=report yet the JSON was built in "split" mode
            # (or vice versa) the dataset will be empty / wrong. Warn loudly.
            if len(samples) == 0:
                tasks_in_file = sorted({s["task"] for s in all_samples
                                        if s.get("split", "train") == split})
                print(f"[CXRInstructDataset] WARNING: 0 samples for task={self.task!r} "
                      f"in split={split!r}. Tasks present in JSON: {tasks_in_file}. "
                      f"Did you forget to rebuild the instruct JSON after changing "
                      f"data.report_mode in train_config.yaml?")
        return samples

    def _select_itc_samples(self, samples: List[Dict], split: str) -> List[Dict]:
        """
        Keep the first sample of each study that has a text embedding.

        A study is identified by "study_id", falling back to "image_path" so
        datasets without study ids still work. Samples are dropped when they
        have no identifier, repeat an already-kept study, have no entry in
        `self.itc_text_embeds`, or carry no image path at all.
        """
        seen_studies = set()
        itc_samples: List[Dict] = []
        n_no_embed = n_no_study = 0
        for s in samples:
            # study_id when present (MIMIC); else fall back to image_path
            # so datasets without study ids (IU-Xray) still smoke-test.
            sid = s.get("study_id") or s.get("image_path")
            if not sid:
                n_no_study += 1
                continue
            if sid in seen_studies:
                continue
            if sid not in self.itc_text_embeds:
                n_no_embed += 1
                continue
            # A sample without any image reference cannot be loaded.
            if not s.get("image_path") and not s.get("image_paths"):
                continue
            seen_studies.add(sid)
            itc_samples.append(s)
        print(f"[CXRInstructDataset] ITC[{split}]: {len(itc_samples):,} studies "
              f"(dropped {n_no_embed:,} w/o text-embed, {n_no_study:,} w/o study_id)")
        return itc_samples

    def __len__(self) -> int:
        """Return the number of samples kept after filtering."""
        return len(self.samples)

    def get_per_sample_weights(self) -> Optional[List[float]]:
        """
        Build per-sample weights for `torch.utils.data.WeightedRandomSampler`
        so that, in expectation, each task occupies its configured fraction of
        drawn training samples — regardless of how many samples of each task
        exist in the JSON.

        Math:
            For task t with N_t samples and configured weight w_t, every
            sample of t gets the weight `w_t / N_t`, so the aggregate weight
            of task t is `N_t * (w_t / N_t) = w_t`.

        Tasks with weight 0 are never drawn. Tasks present in the data but
        absent from `self.task_weights` also get weight 0, and a warning is
        printed so the omission is visible in the logs.

        Returns:
            list of floats of length len(self.samples), or None in ITC mode
            and for single-task datasets (`self.task != "mixed"`), where
            uniform sampling is the right choice.

        Raises:
            ValueError: if the dataset is non-empty but every sample ends up
                with weight 0, since no sample could ever be drawn.
        """
        if getattr(self, "itc_mode", False):
            return None  # ITC: uniform sampling over deduped studies
        if self.task != "mixed":
            return None

        # Count samples per task that actually appear in this dataset.
        counts: Dict[str, int] = {}
        for s in self.samples:
            counts[s["task"]] = counts.get(s["task"], 0) + 1

        # Surface tasks that would be dropped silently by a missing weight.
        unweighted = sorted(t for t in counts if t not in self.task_weights)
        if unweighted:
            print(f"[CXRInstructDataset] WARNING: tasks {unweighted} are present in "
                  f"the data but missing from task_weights; they get weight 0 "
                  f"and will never be drawn.")

        # Per-sample weight = w_task / N_task. Tasks not in task_weights → 0.
        weights = [
            float(self.task_weights.get(s["task"], 0.0)) / counts[s["task"]]
            for s in self.samples
        ]

        # An all-zero weight vector cannot be sampled from; fail here with a
        # clear message rather than later inside the sampler.
        if weights and not any(w > 0 for w in weights):
            raise ValueError(
                f"All per-sample weights are 0: task_weights={self.task_weights} "
                f"does not cover any task present in the data ({sorted(counts)})."
            )

        # Sanity: print effective per-task probabilities once so the actual
        # mix during training is visible in logs (helps catch misconfigured
        # weights vs. JSON-task-set mismatch).
        eff = {t: float(self.task_weights.get(t, 0.0)) for t in counts}
        eff_sum = sum(eff.values()) or 1.0
        eff = {t: round(v / eff_sum, 4) for t, v in eff.items()}
        print(f"[CXRInstructDataset] WeightedRandomSampler effective task mix: "
              f"{eff} (counts: {counts})")
        return weights

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return the training item at `idx`.

        ITC mode yields `{"image", "text_embed"}`. Otherwise the item holds
        "image", "input_ids", "attention_mask", "labels" and "task"; note
        that "task" is a plain str despite the tensor-valued annotation.
        """
        sample = self.samples[idx]

        # ── Load image(s) ───────────────────────────────────────────────────
        # In single-image modes ("image_path" set), `image` is (C, H, W).
        # In multi-image mode ("image_paths" set), `image` is (max_images, C, H, W)
        # padded/truncated to self.max_images by duplicating the first view.
        if sample.get("image_paths"):
            image = self._load_image_stack(sample["image_paths"])   # (N, C, H, W)
        else:
            image = self._load_image(sample["image_path"])          # (C, H, W)

        # ── ITC mode: return (image, precomputed text embedding) ─────────────
        if getattr(self, "itc_mode", False):
            # Same key rule as the dedup step: study_id, else image_path.
            key = sample.get("study_id") or sample["image_path"]
            text_embed = self.itc_text_embeds[key].float()          # (proj_dim,)
            return {"image": image, "text_embed": text_embed}

        # ── Build prompt + target ────────────────────────────────────────────
        training_sample = build_training_sample(
            task=sample["task"],
            target=sample["target"],
            question=sample.get("question"),                        # VQA only
            structured_findings=sample.get("structured_findings"),  # optional
            randomize=(self.split == "train"),                      # only randomize in train
        )

        # ── Tokenize ─────────────────────────────────────────────────────────
        input_ids, labels = self._tokenize(
            prompt=training_sample["prompt"],
            target=training_sample["target"],
        )

        # `input_ids` is un-padded at this point. The collator builds the
        # final attention_mask (1 for real tokens, 0 for batch-level
        # padding) after stacking. We pass an all-ones placeholder of the
        # correct per-sample length so other parts of the pipeline that
        # peek at the dataset output still see the expected key.
        return {
            "image": image,
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "labels": labels,
            "task": sample["task"],     # for per-task logging
        }

    def _load_image_stack(self, image_paths: List[str]) -> torch.Tensor:
        """
        Load a list of images and stack them along a new leading dimension
        of size `self.max_images`.

        Padding rule: if the study has fewer images than max_images, the
        FIRST loaded view is duplicated until max_images is reached.
        Truncation rule: only the first max_images paths are loaded, in the
        order given.

        Every multi-image sample therefore has the same leading size, so
        samples can be batched without per-sample view counts; the cost is
        repeated compute on duplicated views. `torch.stack` requires all
        loaded tensors to share one shape.

        Raises:
            ValueError: if `image_paths` is empty.
        """
        if not image_paths:
            raise ValueError("empty image_paths for sample")
        imgs = [self._load_image(p) for p in image_paths[: self.max_images]]
        while len(imgs) < self.max_images:
            imgs.append(imgs[0].clone())        # pad with the first view
        return torch.stack(imgs, dim=0)         # (N, C, H, W)

    def _load_image(self, image_path: str) -> torch.Tensor:
        """
        Load and transform a chest X-ray image, or load pre-computed patch
        features when a feature cache hits.

        Args:
            image_path: relative path from image_root
                        e.g. "files/p10/p10000032/s50414267/02aa804e.jpg"

        Returns:
            The tensor stored at `{feature_cache_dir}/{image_path}.pt` when
            that file exists; otherwise the RGB image passed through
            `self.transform`, or through a 448x448 resize + ToTensor +
            ImageNet-statistics normalisation when no transform was given.
        """
        # ── Fast path: pre-computed patch features ────────────────────────
        if self.feature_cache_dir is not None:
            cache_path = self.feature_cache_dir / (image_path + ".pt")
            if cache_path.is_file():
                # Load tensor — should be (P, 768) for a single image.
                return torch.load(cache_path, map_location="cpu", weights_only=True)

        # ── Slow path: read JPEG + transform ──────────────────────────────
        full_path = self.image_root / image_path
        # The context manager closes the file handle deterministically;
        # convert() returns an independent, fully loaded RGB copy.
        with Image.open(full_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)
        else:
            # Fallback basic transform if BioViL-T transform not provided
            from torchvision import transforms
            fallback = transforms.Compose([
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
            image = fallback(image)
        return image

    def _tokenize(
        self,
        prompt: str,
        target: str,
    ) -> tuple:
        """
        Tokenize prompt+target. Labels have -100 for prompt tokens
        (so loss is only computed on target tokens).

        Returns UN-PADDED tensors of variable length ≤ cutoff_len; padding
        to a common length is left to the batch collator.

        Returns:
            input_ids: (L,)  L ≤ cutoff_len
            labels:    (L,)  with -100 for prompt positions

        NOTE(review): the prompt length is measured by encoding the prompt on
        its own. This assumes the tokenizer yields the same leading tokens
        for the prompt inside `prompt + " " + target` and appends no trailing
        special token to the stand-alone prompt — TODO confirm for the
        tokenizer in use.
        """
        full_text = prompt + " " + target

        prompt_encoded = self.tokenizer.encode(prompt, add_special_tokens=True)
        full_encoded = self.tokenizer.encode(
            full_text,
            add_special_tokens=True,
            max_length=self.cutoff_len,
            truncation=True,
            # No padding — collator pads to max-in-batch (dynamic padding).
        )

        input_ids = torch.tensor(full_encoded, dtype=torch.long)

        # Labels: -100 for prompt tokens, actual token ids for target tokens.
        labels = input_ids.clone()
        # Clamp so a prompt longer than cutoff_len masks the whole sequence.
        prompt_len = min(len(prompt_encoded), self.cutoff_len)
        labels[:prompt_len] = -100
        return input_ids, labels
# ─── Helper: Build Instruction JSON ─────────────────────────────────────────
def build_instruct_json(
    mimic_cxr_root: str,
    output_path: str,
    chexpert_csv: Optional[str] = None,
    vqa_data_root: Optional[str] = None,
    report_mode: str = "split",
    image_mode: str = "all_views_split",
) -> str:
    """
    Build the unified MIMIC-CXR instruction JSON.

    This function contains no logic of its own: it forwards every argument
    to `build_mimic_cxr_instruct_json` in the sibling `mimic_cxr_builder`
    module and returns whatever that builder returns. Two argument names
    are translated on the way: `mimic_cxr_root` → `mimic_root` and
    `vqa_data_root` → `vqa_root`.

    Per the builder's description, it walks the pre-split MIMIC layout
    (train/valid/test), parses findings/impression from the report .txt
    files, and writes the 14 CheXpert labels (from `*chexpert*.csv`) into
    `structured_findings` as a three-section Positive / Negative / Uncertain
    string. `report_mode` / `image_mode` mirror the IU builder.

    Output entries match the shared schema, e.g.:
        {"image_path": "train/p10/p10000032/s50414267/02aa804e.jpg",
         "task": "findings", "target": "The lungs are clear...",
         "question": null,
         "structured_findings": "Positive Abnormalities: None\\n
                                 Negative Abnormalities: No Finding, ...\\n
                                 Uncertain Abnormalities: None",
         "split": "train", "study_id": "s50414267",
         "subject_id": "p10000032"}
    """
    # Imported lazily so the builder module is only loaded when this helper
    # is actually called.
    from .mimic_cxr_builder import build_mimic_cxr_instruct_json as _build

    builder_kwargs = {
        "mimic_root": mimic_cxr_root,
        "output_path": output_path,
        "chexpert_csv": chexpert_csv,
        "vqa_root": vqa_data_root,
        "report_mode": report_mode,
        "image_mode": image_mode,
    }
    return _build(**builder_kwargs)
|