File size: 2,789 Bytes
28b13fc
 
 
 
 
c61f01a
 
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
c61f01a
28b13fc
c61f01a
 
 
28b13fc
 
 
 
c61f01a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28b13fc
 
c61f01a
 
 
 
28b13fc
8356dae
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
collator.py
-----------
Custom DataCollator that handles variable-length sequences and
stacks images into a batch tensor.

Dynamic padding: input_ids / labels / attention_mask are padded to the
maximum length WITHIN EACH BATCH, not to `cutoff_len`. Batches drawn
mostly from short tasks (e.g. VQA) therefore avoid most of the compute
that would otherwise be wasted on padded positions. Llama runs every
matmul on every position, padded or not, so trimming the padding tail is
a direct FLOP saving (~1.5–2× on this project's task mix; see commit notes).
"""

from dataclasses import dataclass
from typing import List, Dict
import torch


@dataclass
class CXRDataCollator:
    """
    Batch builder that right-pads every text field to the longest sequence
    present in the current batch (**dynamic padding**).

    Args:
        pad_token_id: fill value for the padded tail of input_ids. The padded
                      tail of labels is always -100 (HF Trainer's ignore index
                      for cross-entropy); the padded tail of attention_mask
                      is always 0.

    Returns (from ``__call__``):
        dict with keys "images", "input_ids", "attention_mask", "labels";
        the three text tensors are torch.long with shape (B, max_len).
    """
    pad_token_id: int = 0

    def __call__(self, samples: List[Dict]) -> Dict[str, torch.Tensor]:
        # torch.stack only needs every sample's image to share one shape, so
        # single-image (C,H,W), multi-image (N,C,H,W) and cached-feature
        # (P,D)/(N,P,D) layouts are all handled by the same call.
        pixel_batch = torch.stack([item["image"] for item in samples])

        # ── Dynamic text padding ─────────────────────────────────────────
        # Per-sample lengths are read once and reused for sizing, filling
        # and masking.
        seq_lens = [item["input_ids"].size(0) for item in samples]
        batch_shape = (len(samples), max(seq_lens))

        # Start from fully-padded buffers, then overwrite the real prefix
        # of each row.
        token_batch = torch.full(batch_shape, self.pad_token_id, dtype=torch.long)
        label_batch = torch.full(batch_shape, -100, dtype=torch.long)

        for row, (item, n_tokens) in enumerate(zip(samples, seq_lens)):
            token_batch[row, :n_tokens] = item["input_ids"]
            label_batch[row, :n_tokens] = item["labels"]

        # A position is real (mask == 1) iff its column index is below the
        # row's sequence length; broadcasting (1, max_len) < (B, 1) builds
        # the whole mask at once.
        columns = torch.arange(batch_shape[1]).unsqueeze(0)
        row_lens = torch.tensor(seq_lens).unsqueeze(1)
        mask_batch = (columns < row_lens).to(torch.long)

        return {
            "images":         pixel_batch,  # (B, ...) image-shape-dependent
            "input_ids":      token_batch,  # (B, max_len)
            "attention_mask": mask_batch,   # (B, max_len)
            "labels":         label_batch,  # (B, max_len)
        }


@dataclass
class ITCDataCollator:
    """
    Stage-1 ITC batch builder. Every sample carries an "image" tensor and a
    "text_embed" tensor; each field is stacked along a new leading batch
    dimension. There are no token sequences here, so nothing is padded.
    """

    def __call__(self, samples: List[Dict]) -> Dict[str, torch.Tensor]:
        def stack_field(key: str) -> torch.Tensor:
            # Gather one field from every sample and add the batch axis.
            return torch.stack([item[key] for item in samples])

        return {
            "images":      stack_field("image"),       # (B, ...)
            "text_embeds": stack_field("text_embed"),  # (B, proj_dim)
        }