# cxr-vlm-code / data/collator.py
# convitom · commit 8356dae
"""
collator.py
-----------
Custom DataCollator that handles variable-length sequences and
stacks images into a batch tensor.
Dynamic padding: input_ids / labels / attention_mask are padded to the
maximum length WITHIN EACH BATCH, not to `cutoff_len`. Batches drawn
mostly from short tasks (e.g. VQA) skip the wasted compute on padded
positions entirely — Llama still runs every matmul on every position,
so cutting off the padding tail is a direct FLOP saving (~1.5–2× on
this project's task mix; see commit notes).
"""
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch
@dataclass
class CXRDataCollator:
    """
    Collates a list of dataset samples into a batch with **dynamic padding**.

    Text fields are right-padded to the longest sequence in the batch
    (optionally rounded up to ``pad_to_multiple_of``), not to a global
    ``cutoff_len``.

    Args:
        pad_token_id: token id used to pad input_ids. labels are padded
            with -100 (HF Trainer's ignore index for cross-entropy).
        pad_to_multiple_of: if set to a value > 1, round the padded length
            up to the next multiple of it (e.g. 8, so the sequence axis
            aligns with tensor-core tile sizes). ``None`` (default) or any
            value <= 1 pads to the exact batch maximum.

    Each sample is expected to be a dict with:
        "image":     tensor; all samples in one batch must share a shape
                     (required by ``torch.stack``).
        "input_ids": 1-D tensor of token ids.
        "labels":    sequence with the same length as "input_ids".

    Raises:
        ValueError: if ``samples`` is empty, or a sample's labels length
            differs from its input_ids length.
    """

    pad_token_id: int = 0
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, samples: List[Dict]) -> Dict[str, torch.Tensor]:
        if not samples:
            # Without this, torch.stack / max() fail with opaque messages.
            raise ValueError("CXRDataCollator received an empty batch")

        # Images have fixed shape per dataset config (single-image (C,H,W),
        # multi-image (N,C,H,W), or cached features (P,D)/(N,P,D)) — torch.stack
        # works for any of them.
        images = torch.stack([s["image"] for s in samples])

        # ── Dynamic text padding ─────────────────────────────────────────
        max_len = max(s["input_ids"].size(0) for s in samples)
        multiple = self.pad_to_multiple_of
        if multiple is not None and multiple > 1:
            # Ceil-round to the next multiple.
            max_len = ((max_len + multiple - 1) // multiple) * multiple

        B = len(samples)
        input_ids = torch.full((B, max_len), self.pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros((B, max_len), dtype=torch.long)
        labels = torch.full((B, max_len), -100, dtype=torch.long)

        for i, s in enumerate(samples):
            L = s["input_ids"].size(0)
            # Explicit check: a length-1 / scalar labels would otherwise be
            # silently broadcast across the whole row by the slice assignment.
            if len(s["labels"]) != L:
                raise ValueError(
                    f"sample {i}: labels length {len(s['labels'])} != "
                    f"input_ids length {L}"
                )
            # Right padding: real tokens occupy [0, L), padding follows.
            input_ids[i, :L] = s["input_ids"]
            attention_mask[i, :L] = 1
            labels[i, :L] = s["labels"]

        return {
            "images": images,                  # (B, ...) image-shape-dependent
            "input_ids": input_ids,            # (B, max_len)
            "attention_mask": attention_mask,  # (B, max_len)
            "labels": labels,                  # (B, max_len)
        }
@dataclass
class ITCDataCollator:
    """
    Batch collator for Stage-1 ITC (image-text contrastive) mode.

    Every sample carries an "image" tensor and a precomputed "text_embed"
    tensor; both are stacked along a new leading batch axis. There are no
    token sequences here, so no padding is performed.
    """

    def __call__(self, samples: List[Dict]) -> Dict[str, torch.Tensor]:
        def _stack_field(field: str) -> torch.Tensor:
            # Stack one per-sample field into a (B, ...) tensor.
            return torch.stack([sample[field] for sample in samples])

        # Dict literals evaluate left to right: images are stacked first.
        return {
            "images": _stack_field("image"),            # (B, ...)
            "text_embeds": _stack_field("text_embed"),  # (B, proj_dim)
        }