"""
dataset.py
----------
Dataset class for unified CXR instruction tuning.
Handles four instruction tasks: findings generation, impression generation,
merged report generation (findings + impression), and VQA. An optional ITC
mode instead pairs each study's image(s) with a pre-computed text embedding.

Data sources:
- MIMIC-CXR (physionet.org/content/mimic-cxr/2.1.0)
    └── reports/*.txt → findings + impression labels
- MIMIC-CXR-JPG (physionet.org/content/mimic-cxr-jpg/2.0.0)
    └── files/**/*.jpg → images
- MIMIC-Ext-CXR-QBA (physionet.org/content/mimic-ext-cxr-qba/1.0.0)
    └── *.json → VQA pairs

NOTE: This module does not parse the raw sources itself. It reads a
pre-built instruction JSON (see `build_instruct_json`) and loads images
relative to `image_root`.
"""
|
|
| import json |
| import random |
| from pathlib import Path |
| from typing import Optional, Dict, List, Literal |
|
|
| import torch |
| from torch.utils.data import Dataset |
| from PIL import Image |
|
|
| from .prompt_templates import build_training_sample |
| from model.rad_dino import BioViLTEncoder |
|
|
|
|
# Task selector accepted by CXRInstructDataset: one concrete task name keeps
# only samples of that task, while "mixed" keeps every task found in the split.
TaskType = Literal[
    "findings",
    "impression",
    "report",
    "vqa",
    "mixed",
]
|
|
|
|
| |
|
|
# Section headers used when serialising a merged report target.
REPORT_FINDINGS_TAG = "Findings:"
REPORT_IMPRESSION_TAG = "Impression:"


def format_merged_report(findings: Optional[str], impression: Optional[str]) -> Optional[str]:
    """
    Assemble the target text for the merged "report" task.

    Both inputs are stripped; a section that is None, empty, or
    whitespace-only counts as missing.

    Returns:
        "Findings: <f>\n\nImpression: <i>"  when both sections are present
        "Impression: <i>"                     when only the impression is present
        None                                  when the impression is missing

    Why findings-only samples are rejected: the merged task should teach the
    model to always finish a report with an impression. Training on targets
    that stop after the findings would encourage it to skip that section.
    """
    findings_text = findings.strip() if findings else ""
    impression_text = impression.strip() if impression else ""

    # No impression -> no usable merged target, whatever the findings hold.
    if not impression_text:
        return None

    impression_part = f"{REPORT_IMPRESSION_TAG} {impression_text}"
    if not findings_text:
        return impression_part
    return f"{REPORT_FINDINGS_TAG} {findings_text}\n\n{impression_part}"
|
|
|
|
def parse_generated_report(text: str) -> Dict[str, str]:
    """
    Split a generated/ground-truth report string into its two sections.

    Tolerant to missing sections, extra whitespace, case variations of the
    tags, sections given in reverse order, and empty/None input.

    Args:
        text: report text, ideally "Findings: ...\n\nImpression: ...".

    Returns:
        {"findings": <str>, "impression": <str>} — either may be "".
    """
    import re

    # Nothing to parse (None or ""): report both sections as missing instead
    # of letting re.search raise on a non-string.
    if not text:
        return {"findings": "", "impression": ""}

    flags = re.IGNORECASE | re.DOTALL
    # Findings run lazily up to the next impression tag (or end of text).
    f_match = re.search(r"findings\s*:\s*(.*?)(?=impression\s*:|$)", text, flags=flags)
    i_match = re.search(r"impression\s*:\s*(.*)", text, flags=flags)

    findings = f_match.group(1).strip() if f_match else ""

    impression = ""
    if i_match:
        impression_end = len(text)
        # Reverse order ("Impression: ... Findings: ..."): stop the impression
        # where the findings section starts so it does not swallow that text.
        if f_match and f_match.start() > i_match.start():
            impression_end = f_match.start()
        impression = text[i_match.start(1):impression_end].strip()

    return {"findings": findings, "impression": impression}
|
|
|
|
class CXRInstructDataset(Dataset):
    """
    Unified instruction-following dataset for chest X-ray interpretation.

    Instruction mode (default) — each item is a dict with:
        image:          tensor from `_load_image` (single view) or
                        `_load_image_stack` (sample carries "image_paths")
        input_ids:      tokenized "prompt + target", un-padded
        attention_mask: all-ones tensor shaped like input_ids
        labels:         copy of input_ids with prompt positions set to -100
        task:           the sample's task name (a plain str, not a tensor)

    ITC mode (`itc_mode=True`) — each item is {"image", "text_embed"}, with
    one sample kept per study and the embedding looked up by study id.

    Args:
        data_path:         path to the pre-built instruction JSON (a list of dicts)
        image_root:        root directory that sample image paths are relative to
        tokenizer:         tokenizer exposing `encode(...)` (HuggingFace-style)
        transform:         image transform (e.g. BioViLTEncoder.get_transform());
                           a 448x448 resize/normalize fallback is used when None
        task:              "findings" | "impression" | "report" | "vqa" | "mixed"
        split:             split name matched against each sample's "split"
                           field (samples without one count as "train")
        cutoff_len:        maximum token length of prompt + target
        task_weights:      sampling weights for the mixed task,
                           e.g. {"findings": 0.4, "impression": 0.2, "vqa": 0.4}
        max_images:        number of views per multi-image sample (min 1)
        feature_cache_dir: optional directory of pre-computed feature files
                           stored as `{feature_cache_dir}/{image_path}.pt`
        itc_mode:          enable ITC mode (requires `itc_text_cache`)
        itc_text_cache:    path to a `.pt` file with the text embeddings, either
                           a mapping key -> tensor or {"embeds": mapping}

    Raises:
        ValueError: if `itc_mode` is set without `itc_text_cache`.
    """

    def __init__(
        self,
        data_path: str,
        image_root: str,
        tokenizer,
        transform=None,
        task: "TaskType" = "mixed",
        split: str = "train",
        cutoff_len: int = 512,
        task_weights: Optional[Dict[str, float]] = None,
        max_images: int = 1,
        feature_cache_dir: Optional[str] = None,
        itc_mode: bool = False,
        itc_text_cache: Optional[str] = None,
    ):
        self.image_root = Path(image_root)
        self.tokenizer = tokenizer
        self.transform = transform
        self.task = task
        self.split = split
        self.cutoff_len = cutoff_len
        self.max_images = max(1, int(max_images))
        # Built lazily on the first image that needs it (see _get_fallback_transform).
        self._fallback_transform = None

        self.itc_mode = bool(itc_mode)
        self.itc_text_embeds: Dict[str, torch.Tensor] = {}
        if self.itc_mode:
            if not itc_text_cache:
                raise ValueError("itc_mode=True requires itc_text_cache (.pt path).")
            # NOTE(security): unlike the feature-cache load below, this call
            # does not pass weights_only=True, so on PyTorch versions where
            # that is not the default it can unpickle arbitrary objects. Only
            # point it at trusted files; switch to weights_only=True once the
            # cache is confirmed to hold plain tensors/containers.
            raw = torch.load(itc_text_cache, map_location="cpu")
            # Accept both a bare mapping and a {"embeds": mapping} wrapper.
            self.itc_text_embeds = raw.get("embeds", raw) if isinstance(raw, dict) else raw
            print(f"[CXRInstructDataset] ITC mode: loaded "
                  f"{len(self.itc_text_embeds):,} text embeddings ← {itc_text_cache}")

        self.feature_cache_dir = Path(feature_cache_dir) if feature_cache_dir else None

        self.task_weights = task_weights or {
            "findings": 0.4,
            "impression": 0.2,
            "vqa": 0.4,
        }

        # Must run last: _load_data reads task / itc_mode / itc_text_embeds.
        self.samples = self._load_data(data_path, split)
        print(f"[{split}] Loaded {len(self.samples)} samples for task={task}")

    def _load_data(self, data_path: str, split: str) -> List[Dict]:
        """
        Load the instruction JSON and keep the samples for this dataset.

        Filtering order: by split first, then either the ITC selection (one
        sample per study with a text embedding) or the task filter. A warning
        is printed when the task filter leaves nothing.
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(data_path, "r", encoding="utf-8") as f:
            all_samples = json.load(f)

        # Samples without a "split" field are treated as training samples.
        split_samples = [s for s in all_samples if s.get("split", "train") == split]

        if getattr(self, "itc_mode", False):
            return self._select_itc_samples(split_samples, split)

        if self.task == "mixed":
            samples = split_samples
        else:
            samples = [s for s in split_samples if s["task"] == self.task]

        if not samples:
            tasks_in_file = sorted({s["task"] for s in split_samples})
            print(f"[CXRInstructDataset] WARNING: 0 samples for task={self.task!r} "
                  f"in split={split!r}. Tasks present in JSON: {tasks_in_file}. "
                  f"Did you forget to rebuild the instruct JSON after changing "
                  f"data.report_mode in train_config.yaml?")

        return samples

    def _select_itc_samples(self, samples: List[Dict], split: str) -> List[Dict]:
        """Keep the first usable sample per study that has a text embedding."""
        seen_studies = set()
        itc_samples: List[Dict] = []
        n_no_embed = n_no_study = 0
        for s in samples:
            # Study key: explicit study_id, else the single image path.
            sid = s.get("study_id") or s.get("image_path")
            if not sid:
                n_no_study += 1
                continue
            if sid in seen_studies:
                continue
            if sid not in self.itc_text_embeds:
                n_no_embed += 1
                continue
            # A sample without any image reference cannot be served.
            if not s.get("image_path") and not s.get("image_paths"):
                continue
            seen_studies.add(sid)
            itc_samples.append(s)
        print(f"[CXRInstructDataset] ITC[{split}]: {len(itc_samples):,} studies "
              f"(dropped {n_no_embed:,} w/o text-embed, {n_no_study:,} w/o study_id)")
        return itc_samples

    def __len__(self) -> int:
        """Number of samples kept after split/task (or ITC) filtering."""
        return len(self.samples)

    def get_per_sample_weights(self) -> Optional[List[float]]:
        """
        Build per-sample weights for `torch.utils.data.WeightedRandomSampler`
        so that, in expectation, each task occupies its configured fraction of
        drawn training samples — regardless of how many samples of each task
        exist in the JSON.

        Math:
            For task t with N_t samples and configured weight w_t, every
            sample of t gets the weight `w_t / N_t`, so the aggregate weight
            of task t is `N_t * (w_t / N_t) = w_t`.

        Tasks whose weight is 0, or that are absent from `self.task_weights`,
        get weight 0 and are therefore never drawn.

        Returns:
            A list of floats of length len(self.samples), or None in ITC mode
            and for single-task datasets (`self.task != "mixed"`), where
            uniform sampling is already correct.
        """
        if getattr(self, "itc_mode", False):
            return None

        if self.task != "mixed":
            return None

        # N_t: number of samples per task.
        counts: Dict[str, int] = {}
        for s in self.samples:
            counts[s["task"]] = counts.get(s["task"], 0) + 1

        weights = [
            float(self.task_weights.get(s["task"], 0.0)) / counts[s["task"]]
            for s in self.samples
        ]

        # Log the mix the sampler will realise, normalised over the tasks
        # actually present (the `or 1.0` avoids dividing by zero).
        eff = {t: float(self.task_weights.get(t, 0.0)) for t in counts}
        eff_sum = sum(eff.values()) or 1.0
        eff = {t: round(v / eff_sum, 4) for t, v in eff.items()}
        print(f"[CXRInstructDataset] WeightedRandomSampler effective task mix: "
              f"{eff} (counts: {counts})")
        return weights

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return sample `idx` as a dict (see the class docstring for the keys).

        Note that in instruction mode the "task" entry is a plain string even
        though the annotation advertises tensors.
        """
        sample = self.samples[idx]

        # Multi-view samples carry "image_paths"; single-view ones "image_path".
        if sample.get("image_paths"):
            image = self._load_image_stack(sample["image_paths"])
        else:
            image = self._load_image(sample["image_path"])

        if getattr(self, "itc_mode", False):
            # Same key rule as the ITC selection in _select_itc_samples.
            key = sample.get("study_id") or sample["image_path"]
            text_embed = self.itc_text_embeds[key].float()
            return {"image": image, "text_embed": text_embed}

        training_sample = build_training_sample(
            task=sample["task"],
            target=sample["target"],
            question=sample.get("question"),
            structured_findings=sample.get("structured_findings"),
            randomize=(self.split == "train"),
        )

        input_ids, labels = self._tokenize(
            prompt=training_sample["prompt"],
            target=training_sample["target"],
        )

        return {
            "image": image,
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "labels": labels,
            "task": sample["task"],
        }

    def _load_image_stack(self, image_paths: List[str]) -> torch.Tensor:
        """
        Load a list of images and stack them along a new leading dimension
        of size max_images.

        Padding rule: with fewer images than max_images, the FIRST loaded view
        is duplicated until max_images is reached. Truncation rule: only the
        first max_images paths are loaded, in the given order.

        Every multi-image sample therefore has the same leading size, which
        keeps the stacks batchable without per-sample padding logic; the cost
        is redundant compute on duplicated views.

        Raises:
            ValueError: if `image_paths` is empty.
        """
        if not image_paths:
            raise ValueError("empty image_paths for sample")
        imgs = [self._load_image(p) for p in image_paths[: self.max_images]]
        while len(imgs) < self.max_images:
            imgs.append(imgs[0].clone())
        return torch.stack(imgs, dim=0)

    def _get_fallback_transform(self):
        """Build once, then reuse, the transform applied when none was given."""
        cached = getattr(self, "_fallback_transform", None)
        if cached is None:
            # Imported lazily so torchvision is only needed on this path.
            from torchvision import transforms
            cached = transforms.Compose([
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                # These values match the commonly used ImageNet statistics.
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
            self._fallback_transform = cached
        return cached

    def _load_image(self, image_path: str) -> torch.Tensor:
        """
        Load and transform a chest X-ray image, or load pre-computed
        features when a feature cache hits.

        Args:
            image_path: path relative to image_root,
                        e.g. "files/p10/p10000032/s50414267/02aa804e.jpg"

        Returns:
            The tensor stored at `{feature_cache_dir}/{image_path}.pt` when
            that file exists; otherwise the RGB image passed through
            `self.transform` (or the fallback transform when it is None).
        """
        if self.feature_cache_dir is not None:
            cache_path = self.feature_cache_dir / (image_path + ".pt")
            if cache_path.is_file():
                return torch.load(cache_path, map_location="cpu", weights_only=True)

        full_path = self.image_root / image_path
        # convert() returns an independent in-memory image, so the file handle
        # can be released as soon as the block exits.
        with Image.open(full_path) as raw:
            image = raw.convert("RGB")

        if self.transform is not None:
            return self.transform(image)
        return self._get_fallback_transform()(image)

    def _tokenize(
        self,
        prompt: str,
        target: str,
    ) -> tuple:
        """
        Tokenize prompt + " " + target and build the loss labels.

        Labels are a copy of input_ids with the first `prompt_len` positions
        set to -100, so the loss only covers target tokens. The tensors are
        returned UN-PADDED (length ≤ cutoff_len); padding to a common length
        is left to the batch collator.

        NOTE(review): `prompt_len` comes from encoding the prompt on its own
        with special tokens. This assumes that encoding is a prefix of the
        full encoding (e.g. no EOS appended by the tokenizer) — confirm for
        the tokenizer in use.

        Returns:
            input_ids: (L,) with L ≤ cutoff_len
            labels:    (L,) with -100 at prompt positions
        """
        full_text = prompt + " " + target
        prompt_encoded = self.tokenizer.encode(prompt, add_special_tokens=True)
        full_encoded = self.tokenizer.encode(
            full_text,
            add_special_tokens=True,
            max_length=self.cutoff_len,
            truncation=True,
        )

        input_ids = torch.tensor(full_encoded, dtype=torch.long)

        labels = input_ids.clone()
        # The prompt may itself exceed the cutoff; never mask past it.
        prompt_len = min(len(prompt_encoded), self.cutoff_len)
        labels[:prompt_len] = -100

        return input_ids, labels
|
|
|
|
| |
|
|
def build_instruct_json(
    mimic_cxr_root: str,
    output_path: str,
    chexpert_csv: Optional[str] = None,
    vqa_data_root: Optional[str] = None,
    report_mode: str = "split",
    image_mode: str = "all_views_split",
) -> str:
    """
    Build the unified MIMIC-CXR instruction JSON.

    This function contains no logic of its own: it forwards its arguments to
    `build_mimic_cxr_instruct_json` from the sibling `mimic_cxr_builder`
    module (imported lazily on call) and returns whatever that builder
    returns. Two arguments are renamed on the way: `mimic_cxr_root` is passed
    as `mimic_root` and `vqa_data_root` as `vqa_root`.

    Per the builder's documented contract (its code is not part of this
    module), it walks the pre-split MIMIC layout (train/valid/test), parses
    findings/impression from the report .txt files, and bakes the 14 CheXpert
    labels (oracle, from `*chexpert*.csv`) into `structured_findings` as the
    PNU 3-section string (U-MultiClass, META-CXR format) — the RaDialog
    image + abnormality-guidance setup. `report_mode` / `image_mode` mirror
    the IU builder.

    Output entries match the shared schema, e.g.:
        {"image_path": "train/p10/p10000032/s50414267/02aa804e.jpg",
         "task": "findings", "target": "The lungs are clear...",
         "question": null,
         "structured_findings": "Positive Abnormalities: None\\n
                                 Negative Abnormalities: No Finding, ...\\n
                                 Uncertain Abnormalities: None",
         "split": "train", "study_id": "s50414267",
         "subject_id": "p10000032"}
    """
    from .mimic_cxr_builder import build_mimic_cxr_instruct_json as _build_mimic_json

    # Map this wrapper's parameter names onto the builder's keyword names.
    builder_kwargs = {
        "mimic_root": mimic_cxr_root,
        "output_path": output_path,
        "chexpert_csv": chexpert_csv,
        "vqa_root": vqa_data_root,
        "report_mode": report_mode,
        "image_mode": image_mode,
    }
    return _build_mimic_json(**builder_kwargs)
|
|