""" collate.py """ import torch class CocoCollator: def __init__(self, pad_token_id): self.pad_token_id = pad_token_id def __call__(self, batch): # Stack pixel values pixel_values = torch.stack([item["pixel_values"] for item in batch], dim=0) # Collect caption fields input_ids = [item["input_ids"] for item in batch] attention_masks = [item["attention_mask"] for item in batch] # Pad input_ids input_ids_padded = torch.nn.utils.rnn.pad_sequence( input_ids, batch_first=True, padding_value=self.pad_token_id ) attention_masks_padded = torch.nn.utils.rnn.pad_sequence( attention_masks, batch_first=True, padding_value=0 # attention padding = 0 ) image_ids = [item["image_id"] for item in batch] return { "pixel_values": pixel_values, "input_ids": input_ids_padded, "attention_mask": attention_masks_padded, "image_ids": image_ids, }