| """ |
| collator.py |
| ----------- |
| Custom DataCollator that handles variable-length sequences and |
| stacks images into a batch tensor. |
| |
| Dynamic padding: input_ids / labels / attention_mask are padded to the |
| maximum length WITHIN EACH BATCH, not to `cutoff_len`. Batches drawn |
| mostly from short tasks (e.g. VQA) skip the wasted compute on padded |
| positions entirely — Llama still runs every matmul on every position, |
| so cutting off the padding tail is a direct FLOP saving (~1.5–2× on |
| this project's task mix; see commit notes). |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import List, Dict |
| import torch |
|
|
|
|
@dataclass
class CXRDataCollator:
    """
    Collates a list of dataset samples into a batch with **dynamic padding**.

    Each sample is a dict with:
      - "image": a tensor; all images in a batch must share one shape so
        they can be stacked along a new leading batch dimension.
      - "input_ids": a 1-D sequence of token ids (tensor or list of ints).
      - "labels": a 1-D sequence with the same length as "input_ids".

    Sequences are right-padded to the longest sequence in the batch,
    optionally rounded up to a multiple of ``pad_to_multiple_of``.

    Args:
        pad_token_id: token id used to pad input_ids. labels are padded
            with -100 (HF Trainer's ignore index for cross-entropy).
        pad_to_multiple_of: round the padded length up to a multiple of
            this value (e.g. 8). The default of 1 pads exactly to the
            batch maximum.
    """

    # Token id written into padded input_ids positions.
    pad_token_id: int = 0
    # Padded sequence length is rounded up to a multiple of this (1 = no rounding).
    pad_to_multiple_of: int = 1

    def __post_init__(self) -> None:
        if self.pad_to_multiple_of < 1:
            raise ValueError(
                f"pad_to_multiple_of must be >= 1, got {self.pad_to_multiple_of}"
            )

    def __call__(self, samples: List[Dict]) -> Dict[str, torch.Tensor]:
        """
        Build one padded batch from ``samples``.

        Returns:
            dict with
              - "images": ``torch.stack`` of the per-sample images, (B, *image_shape)
              - "input_ids": (B, T) long, padded with ``pad_token_id``
              - "attention_mask": (B, T) long, 1 on real tokens, 0 on padding
              - "labels": (B, T) long, padded with -100
            where T is the longest sequence length, rounded up to a
            multiple of ``pad_to_multiple_of``.

        Raises:
            ValueError: if ``samples`` is empty, or a sample's "labels"
                shape differs from its "input_ids" shape.
        """
        if not samples:
            raise ValueError("CXRDataCollator received an empty batch")

        images = torch.stack([s["image"] for s in samples])

        # Normalize once so plain lists of ints work as well as tensors.
        ids_list = [torch.as_tensor(s["input_ids"], dtype=torch.long) for s in samples]
        labels_list = [torch.as_tensor(s["labels"], dtype=torch.long) for s in samples]

        # Without this check a length-1 (or 0-d) labels tensor would silently
        # broadcast across the whole row instead of failing.
        for idx, (ids, lab) in enumerate(zip(ids_list, labels_list)):
            if lab.shape != ids.shape:
                raise ValueError(
                    f"sample {idx}: labels shape {tuple(lab.shape)} does not "
                    f"match input_ids shape {tuple(ids.shape)}"
                )

        max_len = max(ids.size(0) for ids in ids_list)
        multiple = self.pad_to_multiple_of
        max_len = -(-max_len // multiple) * multiple  # ceil to a multiple
        B = len(samples)

        input_ids = torch.full((B, max_len), self.pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros((B, max_len), dtype=torch.long)
        labels = torch.full((B, max_len), -100, dtype=torch.long)

        # Right-pad: real tokens occupy the leading L positions of each row.
        for i, (ids, lab) in enumerate(zip(ids_list, labels_list)):
            L = ids.size(0)
            input_ids[i, :L] = ids
            attention_mask[i, :L] = 1
            labels[i, :L] = lab

        return {
            "images": images,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
|
|
|
|
@dataclass
class ITCDataCollator:
    """
    Stage-1 ITC collator for samples shaped like {image, text_embed}.

    ITC mode carries no token sequences, so nothing is padded. The
    per-sample "image" tensors and "text_embed" tensors are each stacked
    along a new leading batch dimension (torch.stack requires every
    tensor within a field to have the same shape).
    """

    def __call__(self, samples: List[Dict]) -> Dict[str, torch.Tensor]:
        # (batch key, per-sample key) pairs; images are stacked first.
        field_pairs = (("images", "image"), ("text_embeds", "text_embed"))
        batch: Dict[str, torch.Tensor] = {}
        for batch_key, sample_key in field_pairs:
            batch[batch_key] = torch.stack([sample[sample_key] for sample in samples])
        return batch
|
|