File size: 20,377 Bytes
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215ecd6
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c61f01a
8356dae
 
28b13fc
 
 
 
 
 
 
 
8356dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c61f01a
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8356dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b961b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8356dae
 
 
b961b41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
8356dae
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c61f01a
 
 
 
 
28b13fc
 
 
c61f01a
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c61f01a
 
28b13fc
 
 
 
c61f01a
 
 
28b13fc
c61f01a
 
 
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c61f01a
 
 
 
 
 
28b13fc
c61f01a
 
28b13fc
 
 
 
 
 
 
 
c61f01a
28b13fc
 
 
 
c61f01a
 
 
28b13fc
 
 
 
 
 
 
 
 
 
 
 
02426e6
 
 
 
 
28b13fc
02426e6
 
 
 
 
215ecd6
 
 
 
02426e6
 
 
 
 
215ecd6
 
 
02426e6
 
28b13fc
02426e6
 
215ecd6
 
 
 
 
 
28b13fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
"""
dataset.py
----------
Dataset class for unified CXR instruction tuning.
Handles 4 tasks: findings generation, impression generation, merged report generation, VQA.

Data sources:
  - MIMIC-CXR (physionet.org/content/mimic-cxr/2.1.0)
      └── reports/*.txt       → findings + impression labels
  - MIMIC-CXR-JPG (physionet.org/content/mimic-cxr-jpg/2.0.0)
      └── files/**/*.jpg      → images
  - MIMIC-Ext-CXR-QBA (physionet.org/content/mimic-ext-cxr-qba/1.0.0)
      └── *.json              → VQA pairs

NOTE: The dataset reads a pre-built instruction JSON; producing that JSON
      from the raw data above is delegated to `build_instruct_json`.
"""

import json
import random
from pathlib import Path
from typing import Optional, Dict, List, Literal

import torch
from torch.utils.data import Dataset
from PIL import Image

from .prompt_templates import build_training_sample
from model.rad_dino import BioViLTEncoder


# Task selector accepted by CXRInstructDataset: one concrete task name keeps
# only samples of that task, while "mixed" keeps every task found in the JSON.
TaskType = Literal["findings", "impression", "report", "vqa", "mixed"]


# ─── Output format helpers (used by both builders and evaluation) ───────────

# Section tags that prefix each part of a merged report target string.
REPORT_FINDINGS_TAG   = "Findings:"
REPORT_IMPRESSION_TAG = "Impression:"


def format_merged_report(findings: Optional[str], impression: Optional[str]) -> Optional[str]:
    """
    Assemble the target text for the merged "report" task.

    Both inputs are whitespace-stripped before use; a value that is None,
    empty, or whitespace-only counts as "missing".

    Returns:
        - findings section, a blank line, then the impression section when
          both are present (each section prefixed by its tag and one space);
        - the impression section alone when findings are missing;
        - None when the impression is missing, even if findings exist.

    Findings-only samples are rejected on purpose: the merged task should
    always end with an impression, and training on targets without one would
    teach the model to leave that section out.
    """
    impression_text = impression.strip() if impression else ""
    if not impression_text:
        # No usable impression → sample is not eligible for the merged task.
        return None

    impression_section = f"{REPORT_IMPRESSION_TAG} {impression_text}"

    findings_text = findings.strip() if findings else ""
    if not findings_text:
        return impression_section

    return f"{REPORT_FINDINGS_TAG} {findings_text}\n\n{impression_section}"


def parse_generated_report(text: str) -> Dict[str, str]:
    """
    Extract the findings and impression sections from a report string.

    Works on generated or ground-truth text. Tags are matched
    case-insensitively, whitespace around the colon is allowed, and each
    captured section is stripped. The findings section stops at the first
    impression tag (or end of text); the impression section runs from the
    first impression tag to the end of the text.

    Returns:
        {"findings": <str>, "impression": <str>}; a section that is not
        found is returned as "".
    """
    import re

    flags = re.IGNORECASE | re.DOTALL
    section_patterns = {
        "findings":   r"findings\s*:\s*(.*?)(?=impression\s*:|$)",
        "impression": r"impression\s*:\s*(.*)",
    }

    sections: Dict[str, str] = {}
    for section_name, pattern in section_patterns.items():
        hit = re.search(pattern, text, flags=flags)
        sections[section_name] = hit.group(1).strip() if hit else ""
    return sections


class CXRInstructDataset(Dataset):
    """
    Unified instruction-following dataset for chest X-ray interpretation.

    Each sample (default mode) is a dict with:
        image:          chest X-ray tensor (C, H, W); (max_images, C, H, W)
                        for samples carrying "image_paths"; or the tensor
                        stored in a feature-cache file when one is found
        input_ids:      tokenized prompt + target, un-padded
        attention_mask: all-ones placeholder of the same length (the real
                        mask is expected to be built by the collator)
        labels:         copy of input_ids with prompt positions set to -100
        task:           the sample's task name (str), for per-task logging

    In ITC mode each sample is {"image", "text_embed"} instead.

    Args:
        data_path:         path to pre-built instruction JSON (see build_instruct_json)
        image_root:        root directory of MIMIC-CXR-JPG images
        tokenizer:         HuggingFace tokenizer (only `.encode` is used here)
        transform:         image transform (BioViLTEncoder.get_transform());
                           a Resize/ToTensor/Normalize fallback is used if None
        task:              "findings" | "impression" | "report" | "vqa" | "mixed"
        split:             "train" | "validate" | "test"
        cutoff_len:        maximum token length
        task_weights:      sampling weights for mixed task {"findings": 0.4, ...}
        max_images:        images per multi-image sample (clamped to >= 1)
        feature_cache_dir: optional directory of pre-computed feature tensors,
                           looked up at `{feature_cache_dir}/{image_path}.pt`
        itc_mode:          Stage-1 ITC mode; yields (image, text_embed) pairs
        itc_text_cache:    .pt file with precomputed text embeddings, either
                           {study_id: tensor} or {"embeds": {study_id: tensor}};
                           required when itc_mode is True

    Raises:
        ValueError: if itc_mode is True and itc_text_cache is not given.
    """

    def __init__(
        self,
        data_path:    str,
        image_root:   str,
        tokenizer,
        transform     = None,
        task:         "TaskType" = "mixed",   # quoted: alias is only needed by type checkers
        split:        str = "train",
        cutoff_len:   int = 512,
        task_weights: Optional[Dict[str, float]] = None,
        max_images:   int = 1,           # >1 only useful in multi_image_merged mode
        feature_cache_dir: Optional[str] = None,
        itc_mode:     bool = False,            # Stage-1 ITC: yield (image, text_embed)
        itc_text_cache: Optional[str] = None,  # {study_id: tensor[proj_dim]} .pt file
    ):
        self.image_root  = Path(image_root)
        self.tokenizer   = tokenizer
        self.transform   = transform
        self.task        = task
        self.split       = split
        self.cutoff_len  = cutoff_len
        self.max_images  = max(1, int(max_images))

        # ── ITC Stage-1 mode ────────────────────────────────────────────────
        # Yields {image, text_embed} instead of tokenized prompt/target. The
        # text embeddings are PRECOMPUTED (CXR-BERT) and keyed by study_id; the
        # dataset is deduped to ONE image per study so every text embedding is
        # a unique InfoNCE positive (no in-batch collisions).
        self.itc_mode = bool(itc_mode)
        self.itc_text_embeds: Dict[str, torch.Tensor] = {}
        if self.itc_mode:
            if not itc_text_cache:
                raise ValueError("itc_mode=True requires itc_text_cache (.pt path).")
            # weights_only=True matches the feature-cache load in _load_image
            # and avoids unpickling arbitrary objects from the cache file.
            raw = torch.load(itc_text_cache, map_location="cpu", weights_only=True)
            # accept either {study_id: tensor} or {"embeds": {...}} wrappers
            self.itc_text_embeds = raw.get("embeds", raw) if isinstance(raw, dict) else raw
            print(f"[CXRInstructDataset] ITC mode: loaded "
                  f"{len(self.itc_text_embeds):,} text embeddings ← {itc_text_cache}")

        # When set, _load_image first checks {feature_cache_dir}/{relpath}.pt
        # and returns the cached (P, 768) patch-feature tensor instead of the
        # raw image. The model detects this by tensor last-dim and skips the
        # frozen encoder entirely. Safe because no random augmentation is
        # applied (Resize + ToTensor + Normalize are deterministic).
        self.feature_cache_dir = Path(feature_cache_dir) if feature_cache_dir else None

        # NOTE: the default mix has no "report" entry, so in mixed mode report
        # samples get sampling weight 0 unless task_weights is supplied.
        self.task_weights = task_weights or {
            "findings":   0.4,
            "impression": 0.2,
            "vqa":        0.4,
        }

        # Load pre-built instruction JSON
        # Format: list of dicts with keys:
        #   image_path, task, target, question (for VQA), structured_findings
        self.samples = self._load_data(data_path, split)
        print(f"[{split}] Loaded {len(self.samples)} samples for task={task}")

    def _load_data(self, data_path: str, split: str) -> List[Dict]:
        """
        Load the instruction JSON and keep the samples for this dataset.

        Samples are first filtered by `split` (entries without a "split" key
        count as "train"). In ITC mode the result is then deduplicated to one
        sample per study that has a text embedding; otherwise it is filtered
        by `self.task` unless the task is "mixed".

        Args:
            data_path: path to the instruction JSON (a list of dicts).
            split:     split name to keep.

        Returns:
            The list of retained sample dicts (possibly empty).
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(data_path, "r", encoding="utf-8") as f:
            all_samples = json.load(f)

        # Filter by split
        split_samples = [s for s in all_samples if s.get("split", "train") == split]

        # ── ITC mode: dedup to ONE image per study with an available text
        # embedding. Each study's canonical text (findings-first, computed at
        # precompute time) is a unique InfoNCE positive. ──
        if getattr(self, "itc_mode", False):
            seen_studies = set()
            itc_samples: List[Dict] = []
            n_no_embed = n_no_study = 0
            for s in split_samples:
                # study_id when present (MIMIC); else fall back to image_path
                # so datasets without study ids (IU-Xray) still smoke-test.
                sid = s.get("study_id") or s.get("image_path")
                if not sid:
                    n_no_study += 1
                    continue
                if sid in seen_studies:
                    continue
                if sid not in self.itc_text_embeds:
                    n_no_embed += 1
                    continue
                # A sample with neither a single path nor a path list cannot
                # be loaded by __getitem__, so drop it (not counted in the log).
                if not s.get("image_path") and not s.get("image_paths"):
                    continue
                seen_studies.add(sid)
                itc_samples.append(s)
            print(f"[CXRInstructDataset] ITC[{split}]: {len(itc_samples):,} studies "
                  f"(dropped {n_no_embed:,} w/o text-embed, {n_no_study:,} w/o study_id)")
            return itc_samples

        # Filter by task
        samples = split_samples
        if self.task != "mixed":
            samples = [s for s in split_samples if s["task"] == self.task]

        # Sanity check: report_mode is enforced at build time, but if the
        # caller asks for task=report yet the JSON was built in "split" mode
        # (or vice versa) the dataset will be empty / wrong. Warn loudly.
        if len(samples) == 0:
            tasks_in_file = sorted({s["task"] for s in split_samples})
            print(f"[CXRInstructDataset] WARNING: 0 samples for task={self.task!r} "
                  f"in split={split!r}. Tasks present in JSON: {tasks_in_file}. "
                  f"Did you forget to rebuild the instruct JSON after changing "
                  f"data.report_mode in train_config.yaml?")

        return samples

    def __len__(self) -> int:
        """Return the number of retained samples."""
        return len(self.samples)

    def get_per_sample_weights(self) -> Optional[List[float]]:
        """
        Build per-sample weights for `torch.utils.data.WeightedRandomSampler`
        so that, in expectation, each task occupies its configured fraction of
        drawn training samples — regardless of how many samples of each task
        exist in the JSON.

        Math:
            For task t with N_t samples in the JSON and configured weight w_t,
            give every sample of t the weight `w_t / N_t`. The aggregate
            probability of drawing ANY sample of task t over one draw becomes
            `N_t * (w_t / N_t) = w_t`, which is exactly the desired ratio.

        Tasks with weight 0 (e.g. VQA on IU-Xray) get weight 0 → never drawn.
        Tasks present in the JSON but absent from `self.task_weights` also get
        weight 0 (loud-failure-on-misconfig is preferable to silent miscounts).

        Returns:
            list of floats of length len(self.samples), or None in ITC mode
            or for a single-task dataset (`self.task != "mixed"`) — in those
            cases weighted sampling is unnecessary and the default uniform
            `RandomSampler` is correct.
        """
        if getattr(self, "itc_mode", False):
            return None      # ITC: uniform sampling over deduped studies

        if self.task != "mixed":
            return None

        # Count samples per task that actually appear in this dataset.
        counts: Dict[str, int] = {}
        for s in self.samples:
            counts[s["task"]] = counts.get(s["task"], 0) + 1

        # Per-sample weight = w_task / N_task. Tasks not in task_weights → 0.
        weights = [
            float(self.task_weights.get(s["task"], 0.0)) / counts[s["task"]]
            for s in self.samples
        ]

        # Sanity: print effective per-task probabilities once so the actual
        # mix during training is visible in logs (helps catch misconfigured
        # weights vs. JSON-task-set mismatch).
        eff = {t: float(self.task_weights.get(t, 0.0)) for t in counts}
        eff_sum = sum(eff.values()) or 1.0   # avoid 0-division when all weights are 0
        eff = {t: round(v / eff_sum, 4) for t, v in eff.items()}
        print(f"[CXRInstructDataset] WeightedRandomSampler effective task mix: "
              f"{eff}  (counts: {counts})")
        return weights

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return sample `idx`.

        In ITC mode the result is {"image", "text_embed"}. Otherwise it is
        {"image", "input_ids", "attention_mask", "labels", "task"}; note that
        "task" is a plain string, not a tensor.
        """
        sample = self.samples[idx]

        # ── Load image(s) ───────────────────────────────────────────────────
        # In single-image modes ("image_path" set), `image` is (C, H, W).
        # In multi-image mode ("image_paths" set), `image` is (max_images, C, H, W)
        # padded/truncated to self.max_images by duplicating the first view.
        if sample.get("image_paths"):
            image = self._load_image_stack(sample["image_paths"])  # (N, C, H, W)
        else:
            image = self._load_image(sample["image_path"])         # (C, H, W)

        # ── ITC mode: return (image, precomputed text embedding) ─────────────
        if getattr(self, "itc_mode", False):
            # Same key rule as the dedup in _load_data: study_id, else image_path.
            key = sample.get("study_id") or sample["image_path"]
            text_embed = self.itc_text_embeds[key].float()                 # (proj_dim,)
            return {"image": image, "text_embed": text_embed}

        # ── Build prompt + target ────────────────────────────────────────────
        training_sample = build_training_sample(
            task                = sample["task"],
            target              = sample["target"],
            question            = sample.get("question"),          # VQA only
            structured_findings = sample.get("structured_findings"),  # optional
            randomize           = (self.split == "train"),         # only randomize in train
        )

        # ── Tokenize ─────────────────────────────────────────────────────────
        input_ids, labels = self._tokenize(
            prompt = training_sample["prompt"],
            target = training_sample["target"],
        )

        # `input_ids` is un-padded at this point. The collator builds the
        # final attention_mask (1 for real tokens, 0 for batch-level
        # padding) after stacking. We pass an all-ones placeholder of the
        # correct per-sample length so other parts of the pipeline that
        # peek at the dataset output still see the expected key.
        return {
            "image":          image,
            "input_ids":      input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "labels":         labels,
            "task":           sample["task"],   # for per-task logging
        }

    def _load_image_stack(self, image_paths: List[str]) -> torch.Tensor:
        """
        Load a list of images and stack them into (max_images, C, H, W).

        Padding rule: if the study has fewer images than max_images, duplicate
        the FIRST view (frontal) until we hit max_images. Truncation rule:
        keep the first max_images views in the order the builder emitted.

        This keeps every multi-image sample the SAME tensor shape so the
        default torch.utils.data collator can batch them without a custom
        collator. The downside is wasted compute on duplicated views; if you
        need a per-sample N, replace with a custom collator that pads inside
        the batch.

        Raises:
            ValueError: if `image_paths` is empty.
        """
        if not image_paths:
            raise ValueError("empty image_paths for sample")
        imgs = [self._load_image(p) for p in image_paths[: self.max_images]]
        while len(imgs) < self.max_images:
            imgs.append(imgs[0].clone())   # pad with frontal
        return torch.stack(imgs, dim=0)    # (N, C, H, W)

    def _load_image(self, image_path: str) -> torch.Tensor:
        """
        Load and transform a chest X-ray image, or load pre-computed patch
        features when a feature cache hits.

        Args:
            image_path: relative path from image_root
                        e.g. "files/p10/p10000032/s50414267/02aa804e.jpg"

        Returns either (C, H, W) for raw images or (P, 768) when a cached
        feature file exists at `{feature_cache_dir}/{image_path}.pt`.
        """
        # ── Fast path: pre-computed patch features ────────────────────────
        if self.feature_cache_dir is not None:
            cache_path = self.feature_cache_dir / (image_path + ".pt")
            if cache_path.is_file():
                # Load tensor — should be (P, 768) for a single image.
                return torch.load(cache_path, map_location="cpu", weights_only=True)

        # ── Slow path: read JPEG + transform ──────────────────────────────
        full_path = self.image_root / image_path
        # Context manager releases the file handle deterministically;
        # convert() returns an independent in-memory image.
        with Image.open(full_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)
        else:
            # Fallback basic transform if BioViL-T transform not provided
            from torchvision import transforms
            fallback = transforms.Compose([
                transforms.Resize((448, 448)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
            image = fallback(image)

        return image

    def _tokenize(
        self,
        prompt: str,
        target: str,
    ) -> tuple:
        """
        Tokenize prompt+target. Labels have -100 for prompt tokens
        (so loss is only computed on target tokens).

        Returns UN-PADDED tensors of variable length ≤ cutoff_len. The
        CXRDataCollator pads them to the max length within each batch,
        so short batches (e.g. all-VQA) skip the padded compute entirely.
        Per-sample padding is no longer applied here; that masking lives
        in the collator now.

        The masked span is the longest common token prefix of the
        prompt-only encoding and the prompt+target encoding. For tokenizers
        that only prepend special tokens this equals the prompt length; for
        tokenizers that also append one (EOS/SEP) it stops the trailing
        special token of the prompt-only encoding from masking the first
        target token.

        NOTE: if the prompt alone fills cutoff_len, every label is -100 and
        the sample contributes no supervised tokens.

        Returns:
            input_ids: (L,)    L ≤ cutoff_len
            labels:    (L,)    with -100 for prompt positions
        """
        full_text      = prompt + " " + target
        prompt_encoded = self.tokenizer.encode(prompt, add_special_tokens=True)
        full_encoded   = self.tokenizer.encode(
            full_text,
            add_special_tokens  = True,
            max_length          = self.cutoff_len,
            truncation          = True,
            # No padding — collator pads to max-in-batch (dynamic padding).
        )

        input_ids = torch.tensor(full_encoded, dtype=torch.long)

        # Count leading positions where both encodings agree. zip() stops at
        # the shorter sequence, so truncation to cutoff_len is respected.
        prompt_len = 0
        for prompt_tok, full_tok in zip(prompt_encoded, full_encoded):
            if prompt_tok != full_tok:
                break
            prompt_len += 1

        # Labels: -100 for prompt tokens, actual token ids for target tokens.
        # (Padding masking now happens in the collator since there's no
        # padding at this stage.)
        labels = input_ids.clone()
        labels[:prompt_len] = -100

        return input_ids, labels


# ─── Helper: Build Instruction JSON ─────────────────────────────────────────

def build_instruct_json(
    mimic_cxr_root:  str,
    output_path:     str,
    chexpert_csv:    Optional[str] = None,
    vqa_data_root:   Optional[str] = None,
    report_mode:     str = "split",
    image_mode:      str = "all_views_split",
) -> str:
    """
    Build the unified MIMIC-CXR instruction JSON.

    This function contains no logic of its own: it forwards every argument
    to `build_mimic_cxr_instruct_json` in the sibling `mimic_cxr_builder`
    module and returns that function's result. Two parameters are renamed on
    the way through: `mimic_cxr_root` → `mimic_root` and
    `vqa_data_root` → `vqa_root`.

    Per the builder's description (not verifiable from this file), it walks
    the pre-split MIMIC layout (train/valid/test), parses findings/impression
    from the report .txt files, and stores the 14 CheXpert labels in
    `structured_findings` as a three-section Positive / Negative / Uncertain
    abnormalities string. `report_mode` / `image_mode` mirror the IU builder.

    Output entries are expected to follow the shared schema, e.g.:
        {"image_path": "train/p10/p10000032/s50414267/02aa804e.jpg",
         "task": "findings", "target": "The lungs are clear...",
         "question": null,
         "structured_findings": "Positive Abnormalities: None ...",
         "split": "train", "study_id": "s50414267",
         "subject_id": "p10000032"}
    """
    # Imported lazily so that importing this module does not pull in the builder.
    from .mimic_cxr_builder import build_mimic_cxr_instruct_json

    # Map this wrapper's parameter names onto the builder's keyword names.
    builder_kwargs = {
        "mimic_root":   mimic_cxr_root,
        "output_path":  output_path,
        "chexpert_csv": chexpert_csv,
        "vqa_root":     vqa_data_root,
        "report_mode":  report_mode,
        "image_mode":   image_mode,
    }
    return build_mimic_cxr_instruct_json(**builder_kwargs)