"""
dataset.py
----------
Dataset class for unified CXR instruction tuning.

Handles findings generation, impression generation (or the merged "report"
task, depending on the build-time report mode), and VQA.

Data sources:
  - MIMIC-CXR          (physionet.org/content/mimic-cxr/2.1.0)
      └── reports/*.txt   → findings + impression labels
  - MIMIC-CXR-JPG      (physionet.org/content/mimic-cxr-jpg/2.0.0)
      └── files/**/*.jpg  → images
  - MIMIC-Ext-CXR-QBA  (physionet.org/content/mimic-ext-cxr-qba/1.0.0)
      └── *.json          → VQA pairs

NOTE: Data loading implementation left as TODO — fill in paths and JSON
structure once data is downloaded.
"""
import json
import random
import re
from collections import Counter
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple

import torch
from torch.utils.data import Dataset
from PIL import Image

from .prompt_templates import build_training_sample
from model.rad_dino import BioViLTEncoder

TaskType = Literal["findings", "impression", "report", "vqa", "mixed"]


# ─── Output format helpers (used by both builders and evaluation) ───────────

REPORT_FINDINGS_TAG = "Findings:"
REPORT_IMPRESSION_TAG = "Impression:"

# Compiled once at import time (used by parse_generated_report).
# Findings body: lazily up to the first "impression:" tag or end of string.
# Impression body: everything after the first "impression:" tag.
_FINDINGS_RE = re.compile(
    r"findings\s*:\s*(.*?)(?=impression\s*:|$)", flags=re.IGNORECASE | re.DOTALL
)
_IMPRESSION_RE = re.compile(r"impression\s*:\s*(.*)", flags=re.IGNORECASE | re.DOTALL)


def format_merged_report(findings: Optional[str], impression: Optional[str]) -> Optional[str]:
    """
    Build the target string for the merged "report" task.

    Returns:
        "Findings: <findings>\\n\\nImpression: <impression>"  if both sections exist
        "Impression: <impression>"                           if only impression exists
        None                                                 if neither (or only findings — see below)

    Note: we deliberately drop samples that have ONLY findings. The merged
    task is meant to teach the model to produce a full report ending with an
    impression. Samples without impression would push the model to omit the
    impression section, which is exactly what we do not want.
    """
    has_f = findings and findings.strip()
    has_i = impression and impression.strip()
    if has_f and has_i:
        return (
            f"{REPORT_FINDINGS_TAG} {findings.strip()}\n\n"
            f"{REPORT_IMPRESSION_TAG} {impression.strip()}"
        )
    if has_i and not has_f:
        return f"{REPORT_IMPRESSION_TAG} {impression.strip()}"
    return None


def parse_generated_report(text: Optional[str]) -> Dict[str, str]:
    """
    Split a generated/ground-truth report string into its two sections.

    Tolerant to missing sections, extra whitespace, and case variations of
    the tag. A ``None`` input is treated as the empty string.

    Returns:
        {"findings": <str>, "impression": <str>} — either may be "".
    """
    text = text or ""
    f_match = _FINDINGS_RE.search(text)
    i_match = _IMPRESSION_RE.search(text)
    return {
        "findings": f_match.group(1).strip() if f_match else "",
        "impression": i_match.group(1).strip() if i_match else "",
    }


@lru_cache(maxsize=1)
def _fallback_image_transform():
    """Return the default Resize(448) + ToTensor + ImageNet-Normalize pipeline.

    Built once and reused. It is only needed when the dataset has no explicit
    transform, so torchvision is imported lazily here.
    """
    from torchvision import transforms

    return transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


class CXRInstructDataset(Dataset):
    """
    Unified instruction-following dataset for chest X-ray interpretation.

    Each sample (non-ITC mode) contains:
        image:          chest X-ray tensor (C, H, W); (max_images, C, H, W) for
                        multi-image samples; or cached patch features when a
                        feature-cache file exists for the image
        input_ids:      tokenized prompt + target (un-padded)
        attention_mask: all-ones placeholder of the same length (the collator
                        builds the real padding mask)
        labels:         input_ids with prompt tokens masked to -100
        task:           task name string (for per-task logging)

    In ITC mode each sample is {"image", "text_embed"} instead.

    Args:
        data_path:         path to pre-built instruction JSON (see build_instruct_json.py)
        image_root:        root directory of MIMIC-CXR-JPG images
        tokenizer:         HuggingFace tokenizer
        transform:         image transform (BioViLTEncoder.get_transform()); a
                           default Resize/ToTensor/Normalize is used when None
        task:              "findings" | "impression" | "report" | "vqa" | "mixed"
        split:             "train" | "validate" | "test"
        cutoff_len:        maximum token length
        task_weights:      sampling weights for mixed task {"findings": 0.4, ...}
        max_images:        views per multi-image sample (padded/truncated);
                           >1 only useful in multi_image_merged mode
        feature_cache_dir: optional directory of precomputed `{image_path}.pt`
                           patch-feature files
        itc_mode:          Stage-1 ITC mode: yield (image, text_embed)
        itc_text_cache:    .pt file holding {study_id: tensor[proj_dim]}
                           (required when itc_mode=True)
    """

    def __init__(
        self,
        data_path: str,
        image_root: str,
        tokenizer,
        transform=None,
        task: TaskType = "mixed",
        split: str = "train",
        cutoff_len: int = 512,
        task_weights: Optional[Dict[str, float]] = None,
        max_images: int = 1,                     # >1 only useful in multi_image_merged mode
        feature_cache_dir: Optional[str] = None,
        itc_mode: bool = False,                  # Stage-1 ITC: yield (image, text_embed)
        itc_text_cache: Optional[str] = None,    # {study_id: tensor[proj_dim]} .pt file
    ):
        self.image_root = Path(image_root)
        self.tokenizer = tokenizer
        self.transform = transform
        self.task = task
        self.split = split
        self.cutoff_len = cutoff_len
        self.max_images = max(1, int(max_images))

        # ── ITC Stage-1 mode ────────────────────────────────────────────────
        # Yields {image, text_embed} instead of tokenized prompt/target. The
        # text embeddings are PRECOMPUTED (CXR-BERT) and keyed by study_id; the
        # dataset is deduped to ONE image per study so every text embedding is
        # a unique InfoNCE positive (no in-batch collisions).
        self.itc_mode = bool(itc_mode)
        self.itc_text_embeds: Dict[str, torch.Tensor] = {}
        if self.itc_mode:
            if not itc_text_cache:
                raise ValueError("itc_mode=True requires itc_text_cache (.pt path).")
            # weights_only=True: the cache is plain tensors/containers, so the
            # restricted unpickler suffices (matches the feature-cache load).
            raw = torch.load(itc_text_cache, map_location="cpu", weights_only=True)
            # accept either {study_id: tensor} or {"embeds": {...}} wrappers
            self.itc_text_embeds = raw.get("embeds", raw) if isinstance(raw, dict) else raw
            print(f"[CXRInstructDataset] ITC mode: loaded "
                  f"{len(self.itc_text_embeds):,} text embeddings ← {itc_text_cache}")

        # When set, _load_image first checks {feature_cache_dir}/{relpath}.pt
        # and returns the cached (P, 768) patch-feature tensor instead of the
        # raw image. The model detects this by tensor last-dim and skips the
        # frozen encoder entirely. Safe because no random augmentation is
        # applied (Resize + ToTensor + Normalize are deterministic).
        self.feature_cache_dir = Path(feature_cache_dir) if feature_cache_dir else None

        self.task_weights = task_weights or {
            "findings": 0.4,
            "impression": 0.2,
            "vqa": 0.4,
        }

        # Load pre-built instruction JSON
        # Format: list of dicts with keys:
        #   image_path, task, target, question (for VQA), structured_findings
        self.samples = self._load_data(data_path, split)
        print(f"[{split}] Loaded {len(self.samples)} samples for task={task}")

    def _load_data(self, data_path: str, split: str) -> List[Dict]:
        """Load the instruction JSON and filter samples by split and task type.

        In ITC mode the result is instead deduped to one sample per study that
        has a precomputed text embedding and at least one image path.
        """
        with open(data_path, "r", encoding="utf-8") as f:
            all_samples = json.load(f)

        # Filter by split
        samples = [s for s in all_samples if s.get("split", "train") == split]

        # ── ITC mode: dedup to ONE image per study with an available text
        # embedding. Each study's canonical text (findings-first, computed at
        # precompute time) is a unique InfoNCE positive. ──
        if getattr(self, "itc_mode", False):
            seen_studies = set()
            missing_embed = set()   # unique study keys with no text embedding
            itc_samples: List[Dict] = []
            n_no_study = 0
            for s in samples:
                # study_id when present (MIMIC); else fall back to image_path
                # so datasets without study ids (IU-Xray) still smoke-test.
                sid = s.get("study_id") or s.get("image_path")
                if not sid:
                    n_no_study += 1
                    continue
                if sid in seen_studies or sid in missing_embed:
                    continue
                if sid not in self.itc_text_embeds:
                    missing_embed.add(sid)
                    continue
                # Keep any sample that has an image (single- or multi-view);
                # samples with no image at all cannot be used for ITC.
                if not s.get("image_path") and not s.get("image_paths"):
                    continue
                seen_studies.add(sid)
                itc_samples.append(s)
            print(f"[CXRInstructDataset] ITC[{split}]: {len(itc_samples):,} studies "
                  f"(dropped {len(missing_embed):,} studies w/o text-embed, "
                  f"{n_no_study:,} samples w/o study_id)")
            return itc_samples

        # Filter by task
        if self.task != "mixed":
            samples = [s for s in samples if s["task"] == self.task]

        # Sanity check: report_mode is enforced at build time, but if the
        # caller asks for task=report yet the JSON was built in "split" mode
        # (or vice versa) the dataset will be empty / wrong. Warn loudly.
        if len(samples) == 0:
            tasks_in_file = sorted({s["task"] for s in all_samples
                                    if s.get("split", "train") == split})
            print(f"[CXRInstructDataset] WARNING: 0 samples for task={self.task!r} "
                  f"in split={split!r}. Tasks present in JSON: {tasks_in_file}. "
                  f"Did you forget to rebuild the instruct JSON after changing "
                  f"data.report_mode in train_config.yaml?")

        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def get_per_sample_weights(self) -> Optional[List[float]]:
        """
        Build per-sample weights for `torch.utils.data.WeightedRandomSampler`
        so that, in expectation, each task occupies its configured fraction of
        drawn training samples — regardless of how many samples of each task
        exist in the JSON.

        Math:
            For task t with N_t samples in the JSON and configured weight w_t,
            give every sample of t the weight `w_t / N_t`. The aggregate
            probability of drawing ANY sample of task t over one draw becomes
            `N_t * (w_t / N_t) = w_t` (normalized by the sum of weights),
            which is exactly the desired ratio.

        Tasks with weight 0 (e.g. VQA on IU-Xray) get weight 0 → never drawn.
        Tasks present in the JSON but absent from `self.task_weights` also get
        weight 0; a WARNING naming every such task is printed so a
        misconfiguration (e.g. "report" samples with the default weights) is
        visible.

        Returns:
            list of floats of length len(self.samples), or None if this is a
            single-task dataset (`self.task != "mixed"`) or ITC mode — in that
            case weighted sampling is unnecessary and the default uniform
            `RandomSampler` is correct.

        Raises:
            ValueError: if any present task has a negative weight, or if every
                present task has zero weight (nothing could ever be drawn).
        """
        if getattr(self, "itc_mode", False):
            return None   # ITC: uniform sampling over deduped studies
        if self.task != "mixed":
            return None

        # Count samples per task that actually appear in this dataset.
        counts: Dict[str, int] = dict(Counter(s["task"] for s in self.samples))

        # Configured weight per present task; tasks not in task_weights → 0.
        task_w = {t: float(self.task_weights.get(t, 0.0)) for t in counts}

        negative = {t: w for t, w in task_w.items() if w < 0.0}
        if negative:
            raise ValueError(f"[CXRInstructDataset] negative task weights: {negative}")

        zero = sorted(t for t, w in task_w.items() if w == 0.0)
        if zero:
            print(f"[CXRInstructDataset] WARNING: tasks {zero} are present in the "
                  f"data but have zero/missing sampling weight and will NEVER be "
                  f"drawn. task_weights={self.task_weights}")

        total = sum(task_w.values())
        if counts and total <= 0.0:
            raise ValueError(
                f"[CXRInstructDataset] all task weights are zero for the tasks "
                f"present in the data ({sorted(counts)}); task_weights="
                f"{self.task_weights}")

        # Per-sample weight = w_task / N_task.
        weights = [task_w[s["task"]] / counts[s["task"]] for s in self.samples]

        # Print effective per-task probabilities once so the actual mix during
        # training is visible in logs.
        denom = total or 1.0
        eff = {t: round(w / denom, 4) for t, w in task_w.items()}
        print(f"[CXRInstructDataset] WeightedRandomSampler effective task mix: "
              f"{eff} (counts: {counts})")
        return weights

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample = self.samples[idx]

        # ── Load image(s) ───────────────────────────────────────────────────
        # In single-image modes ("image_path" set), `image` is (C, H, W).
        # In multi-image mode ("image_paths" set), `image` is (max_images, C, H, W)
        # padded/truncated to self.max_images by duplicating the first view.
        if sample.get("image_paths"):
            image = self._load_image_stack(sample["image_paths"])   # (N, C, H, W)
        else:
            image = self._load_image(sample["image_path"])          # (C, H, W)

        # ── ITC mode: return (image, precomputed text embedding) ─────────────
        # Key lookup mirrors the dedup key used in _load_data.
        if getattr(self, "itc_mode", False):
            key = sample.get("study_id") or sample["image_path"]
            text_embed = self.itc_text_embeds[key].float()          # (proj_dim,)
            return {"image": image, "text_embed": text_embed}

        # ── Build prompt + target ────────────────────────────────────────────
        training_sample = build_training_sample(
            task=sample["task"],
            target=sample["target"],
            question=sample.get("question"),                         # VQA only
            structured_findings=sample.get("structured_findings"),   # optional
            randomize=(self.split == "train"),                       # only randomize in train
        )

        # ── Tokenize ─────────────────────────────────────────────────────────
        input_ids, labels = self._tokenize(
            prompt=training_sample["prompt"],
            target=training_sample["target"],
        )

        # `input_ids` is un-padded at this point. The collator builds the
        # final attention_mask (1 for real tokens, 0 for batch-level
        # padding) after stacking. We pass an all-ones placeholder of the
        # correct per-sample length so other parts of the pipeline that
        # peek at the dataset output still see the expected key.
        return {
            "image": image,
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "labels": labels,
            "task": sample["task"],   # for per-task logging
        }

    def _load_image_stack(self, image_paths: List[str]) -> torch.Tensor:
        """
        Load a list of images and stack them into (max_images, C, H, W).

        Padding rule: if the study has fewer images than max_images, duplicate
        the FIRST view (frontal) until we hit max_images.
        Truncation rule: keep the first max_images views in the order the
        builder emitted.

        This keeps every multi-image sample the SAME tensor shape so the
        default torch.utils.data collator can batch them without a custom
        collator. The downside is wasted compute on duplicated views; if you
        need a per-sample N, replace with a custom collator that pads inside
        the batch.

        NOTE(review): torch.stack requires every loaded view to share a shape;
        mixing feature-cache hits (P, 768) with raw images (C, H, W) in one
        study would fail here.
        """
        if not image_paths:
            raise ValueError("empty image_paths for sample")
        imgs = [self._load_image(p) for p in image_paths[: self.max_images]]
        while len(imgs) < self.max_images:
            imgs.append(imgs[0].clone())   # pad with frontal
        return torch.stack(imgs, dim=0)    # (N, C, H, W)

    def _load_image(self, image_path: str) -> torch.Tensor:
        """
        Load and transform a chest X-ray image, or load pre-computed patch
        features when a feature cache hits.

        Args:
            image_path: relative path from image_root
                        e.g. "files/p10/p10000032/s50414267/02aa804e.jpg"

        Returns either (C, H, W) for raw images or (P, 768) when a cached
        feature file exists at `{feature_cache_dir}/{image_path}.pt`.
        """
        # ── Fast path: pre-computed patch features ────────────────────────
        if self.feature_cache_dir is not None:
            cache_path = self.feature_cache_dir / (image_path + ".pt")
            if cache_path.is_file():
                # Load tensor — should be (P, 768) for a single image.
                return torch.load(cache_path, map_location="cpu", weights_only=True)

        # ── Slow path: read JPEG + transform ──────────────────────────────
        full_path = self.image_root / image_path
        # Context manager closes the underlying file deterministically;
        # convert() returns a new, fully-loaded image, so it stays valid.
        with Image.open(full_path) as raw:
            image = raw.convert("RGB")

        # Fallback basic transform if BioViL-T transform not provided.
        transform = self.transform if self.transform is not None else _fallback_image_transform()
        return transform(image)

    def _tokenize(
        self,
        prompt: str,
        target: str,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Tokenize prompt+target. Labels have -100 for prompt tokens (so loss is
        only computed on target tokens).

        Returns UN-PADDED tensors of variable length ≤ cutoff_len. The
        CXRDataCollator pads them to the max length within each batch, so
        short batches (e.g. all-VQA) skip the padded compute entirely.
        Per-sample padding is no longer applied here; that masking lives in
        the collator now.

        Returns:
            input_ids: (L,)  L ≤ cutoff_len
            labels:    (L,)  with -100 for prompt positions

        NOTE(review): no EOS token is appended here — confirm that either the
        target produced by build_training_sample or the tokenizer's special
        tokens terminate the sequence, otherwise the model is never trained
        to stop.
        """
        full_text = prompt + " " + target

        # The prompt is encoded on its own only to know how many leading tokens
        # to mask. NOTE(review): this assumes the prompt's tokens are a prefix
        # of the full text's tokens at the whitespace boundary — confirm for
        # the tokenizer in use.
        prompt_encoded = self.tokenizer.encode(prompt, add_special_tokens=True)
        full_encoded = self.tokenizer.encode(
            full_text,
            add_special_tokens=True,
            max_length=self.cutoff_len,
            truncation=True,
            # No padding — collator pads to max-in-batch (dynamic padding).
        )

        input_ids = torch.tensor(full_encoded, dtype=torch.long)

        # Labels: -100 for prompt tokens, actual token ids for target tokens.
        # (Padding masking happens in the collator since there's no padding
        # at this stage.) If the prompt alone reaches cutoff_len, every label
        # becomes -100 and the sample contributes no target tokens.
        labels = input_ids.clone()
        prompt_len = min(len(prompt_encoded), self.cutoff_len)
        labels[:prompt_len] = -100

        return input_ids, labels


# ─── Helper: Build Instruction JSON ─────────────────────────────────────────

def build_instruct_json(
    mimic_cxr_root: str,
    output_path: str,
    chexpert_csv: Optional[str] = None,
    vqa_data_root: Optional[str] = None,
    report_mode: str = "split",
    image_mode: str = "all_views_split",
) -> str:
    """
    Build the unified MIMIC-CXR instruction JSON.

    Thin delegate to `data.mimic_cxr_builder.build_mimic_cxr_instruct_json`,
    which walks the pre-split MIMIC layout (train/valid/test), parses
    findings/impression from the report .txt files, and bakes the 14 CheXpert
    labels (oracle, from `*chexpert*.csv`) into `structured_findings` as the
    PNU 3-section string (U-MultiClass, META-CXR format) — the RaDialog image +
    abnormality-guidance setup. `report_mode` / `image_mode` mirror the IU
    builder.

    Output entries match the shared schema, e.g.:
        {"image_path": "train/p10/p10000032/s50414267/02aa804e.jpg",
         "task": "findings",
         "target": "The lungs are clear...",
         "question": null,
         "structured_findings": "Positive Abnormalities: None\\n Negative Abnormalities: No Finding, ...\\n Uncertain Abnormalities: None",
         "split": "train",
         "study_id": "s50414267",
         "subject_id": "p10000032"}
    """
    from .mimic_cxr_builder import build_mimic_cxr_instruct_json
    return build_mimic_cxr_instruct_json(
        mimic_root=mimic_cxr_root,
        output_path=output_path,
        chexpert_csv=chexpert_csv,
        vqa_root=vqa_data_root,
        report_mode=report_mode,
        image_mode=image_mode,
    )