|
|
""" |
|
|
collate.py: batch collation for COCO-style image/text samples (stacked images plus padded token sequences).
|
|
""" |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
class CocoCollator:
    """Collate a list of COCO-style samples into a single padded batch.

    Each sample is a mapping that provides the keys ``"pixel_values"``,
    ``"input_ids"``, ``"attention_mask"`` and ``"image_id"``.
    """

    def __init__(self, pad_token_id):
        """Store the fill value used when padding ``input_ids``.

        Args:
            pad_token_id: Value written into ``input_ids`` positions past the
                end of a sample's own sequence.
        """
        self.pad_token_id = pad_token_id

    def __call__(self, batch):
        """Merge ``batch`` into tensors suitable for a model forward pass.

        Args:
            batch: Sequence of sample mappings (see the class docstring).
                All ``"pixel_values"`` tensors must share one shape, since
                they are combined with ``torch.stack``. ``"input_ids"`` and
                ``"attention_mask"`` may differ in length between samples.
                They are assumed to be 1-D per sample (TODO: confirm against
                the dataset).

        Returns:
            dict with:
                ``"pixel_values"``: per-sample images stacked along a new
                    leading batch dimension.
                ``"input_ids"``: sequences right-padded to the longest one
                    with ``self.pad_token_id`` (``batch_first=True``).
                ``"attention_mask"``: masks right-padded with ``0`` the same
                    way.
                ``"image_ids"``: plain Python list of the samples'
                    ``"image_id"`` values, in batch order.
        """

        def column(key):
            # Pull one field out of every sample, preserving batch order.
            return [sample[key] for sample in batch]

        pad = torch.nn.utils.rnn.pad_sequence

        stacked_images = torch.stack(column("pixel_values"), dim=0)
        token_seqs = column("input_ids")
        mask_seqs = column("attention_mask")

        return {
            "pixel_values": stacked_images,
            "input_ids": pad(
                token_seqs, batch_first=True, padding_value=self.pad_token_id
            ),
            # Padded positions get mask 0 so they are distinguishable from
            # real tokens (mask 1 in the inputs used here).
            "attention_mask": pad(mask_seqs, batch_first=True, padding_value=0),
            "image_ids": column("image_id"),
        }
|
|
|