"""
collator.py
-----------
Custom data collators that stack images into a batch tensor and, for the
token-sequence collator, pad variable-length text.

Dynamic padding: input_ids / labels / attention_mask are padded to the
maximum length WITHIN EACH BATCH, not to `cutoff_len`. Batches drawn mostly
from short tasks (e.g. VQA) therefore carry far fewer padded positions. Llama
still runs every matmul on every position, padding included, so trimming the
padding tail is a direct FLOP saving (~1.5–2× on this project's task mix; see
commit notes).
"""

from dataclasses import dataclass
from typing import Dict, List, Optional

import torch


@dataclass
class CXRDataCollator:
    """
    Collates a list of dataset samples into a batch with **dynamic padding**.

    Each sample is a dict with at least ``"image"``, ``"input_ids"`` and
    ``"labels"`` tensors; ``input_ids`` and ``labels`` are 1-D and must have
    the same length. Sequences are right-padded.

    Args:
        pad_token_id: token id used to pad input_ids.
        label_pad_token_id: value used to pad labels. Defaults to -100
            (HF Trainer's ignore index for cross-entropy).
        pad_to_multiple_of: if set, round the padded length up to a multiple
            of this value (e.g. 8 for tensor-core-friendly shapes). ``None``
            (the default) pads to exactly the longest sequence in the batch.

    Raises:
        ValueError: on an empty batch, when a sample's labels and input_ids
            differ in length, or (at construction) when ``pad_to_multiple_of``
            is not a positive integer.
    """

    pad_token_id: int = 0
    label_pad_token_id: int = -100
    pad_to_multiple_of: Optional[int] = None

    def __post_init__(self) -> None:
        if self.pad_to_multiple_of is not None and self.pad_to_multiple_of <= 0:
            raise ValueError(
                "pad_to_multiple_of must be a positive int, "
                f"got {self.pad_to_multiple_of}"
            )

    def __call__(self, samples: List[Dict]) -> Dict[str, torch.Tensor]:
        if not samples:
            # Fail with a clear message instead of torch.stack's generic error.
            raise ValueError("CXRDataCollator received an empty batch")

        # Images have fixed shape per dataset config (single-image (C,H,W),
        # multi-image (N,C,H,W), or cached features (P,D)/(N,P,D)) — torch.stack
        # works for any of them as long as every sample in the batch agrees.
        images = torch.stack([s["image"] for s in samples])

        # ── Dynamic text padding ─────────────────────────────────────────
        lengths = [s["input_ids"].size(0) for s in samples]
        max_len = max(lengths)
        if self.pad_to_multiple_of is not None:
            m = self.pad_to_multiple_of
            max_len = ((max_len + m - 1) // m) * m  # integer ceil to multiple

        B = len(samples)
        input_ids = torch.full((B, max_len), self.pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros((B, max_len), dtype=torch.long)
        labels = torch.full(
            (B, max_len), self.label_pad_token_id, dtype=torch.long
        )

        for i, (s, L) in enumerate(zip(samples, lengths)):
            label_len = s["labels"].size(0)
            if label_len != L:
                raise ValueError(
                    f"sample {i}: labels length {label_len} != "
                    f"input_ids length {L}"
                )
            # Right-padding: real tokens occupy the first L positions.
            input_ids[i, :L] = s["input_ids"]
            attention_mask[i, :L] = 1
            labels[i, :L] = s["labels"]

        return {
            "images": images,                  # (B, ...) image-shape-dependent
            "input_ids": input_ids,            # (B, max_len)
            "attention_mask": attention_mask,  # (B, max_len)
            "labels": labels,                  # (B, max_len)
        }


@dataclass
class ITCDataCollator:
    """
    Collator for Stage-1 ITC mode.

    Each sample is {image, text_embed}; both are stacked into batch tensors.
    There is no text padding because there are no token sequences here.

    Raises:
        ValueError: on an empty batch.
    """

    def __call__(self, samples: List[Dict]) -> Dict[str, torch.Tensor]:
        if not samples:
            raise ValueError("ITCDataCollator received an empty batch")
        images = torch.stack([s["image"] for s in samples])            # (B, ...)
        text_embeds = torch.stack([s["text_embed"] for s in samples])  # (B, proj_dim)
        return {"images": images, "text_embeds": text_embeds}