| |
| |
| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from tokenizer.tiktoken import IGNORE_INDEX |
|
|
| from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
def padded_collate(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
    """Pad a batch of sequences to the longest sequence length in the batch, and
    convert integer lists to tensors.

    Args:
        batch (List[Dict[str, List[int]]]): A list of dictionaries containing input, label pairs.
        padding_idx (int): Padding index for input ids. Defaults to 0.
        ignore_idx (int): Padding index for labels. Defaults to -100.

    Returns:
        Dict[str, torch.Tensor]: Collated input and label tensors.

    Example:
        >>> token_pairs = [
        ...     {"input_ids": [1, 2, 3], "labels": [4, 5, 6]},
        ...     {"input_ids": [7], "labels": [10]},
        ... ]
        >>> collated = padded_collate(
        ...     batch=token_pairs,
        ...     padding_idx=0,
        ...     ignore_idx=-100,
        ... )
        >>> collated["input_ids"]
        tensor([[1, 2, 3], [7, 0, 0]])
        >>> collated["labels"]
        tensor([[4, 5, 6], [10, -100, -100]])
    """
    # pad_sequence requires tensor elements, but the documented schema (and the
    # doctest above) supplies plain int lists. torch.as_tensor handles both:
    # it wraps lists and passes pre-built tensors through without copying.
    input_ids = pad_sequence(
        [torch.as_tensor(x["input_ids"]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    labels = pad_sequence(
        [torch.as_tensor(x["labels"]) for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )

    input_ids_seq_len = input_ids.shape[-1]
    labels_seq_len = labels.shape[-1]

    # Tokens and labels may pad to different lengths across the batch;
    # right-pad the shorter one so both tensors share a sequence dimension.
    if input_ids_seq_len > labels_seq_len:
        labels = F.pad(
            labels, (0, input_ids_seq_len - labels_seq_len), value=ignore_idx
        )
    elif labels_seq_len > input_ids_seq_len:
        input_ids = F.pad(
            input_ids,
            (0, labels_seq_len - input_ids_seq_len),
            value=padding_idx,
        )
    return {"input_ids": input_ids, "labels": labels}
|
|
|
|
| |
@dataclass
class MultiModalCollator:
    """Collator for multimodal (text + tiled image) samples.

    Attributes:
        padding_idx (int): Padding index for input token ids. Defaults to 128004.
        ignore_idx (int): Padding index for labels. Defaults to IGNORE_INDEX.
        pad_max_tiles (Optional[int]): Fixed tile count to pad every image to.
            If None, pads to the largest tile count in the batch. Defaults to None.
        pad_max_images (Optional[int]): Fixed image count to pad every sample to.
            If None, pads to the largest image count in the batch. Defaults to None.
    """

    padding_idx: int = 128004
    ignore_idx: int = IGNORE_INDEX
    pad_max_tiles: Optional[int] = None
    pad_max_images: Optional[int] = None

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Pad a batch of text sequences, tiled image tensors, and aspect ratios.
        This can be used for both training and inference.

        ``batch`` is expected to be a list of sample dicts containing the following::
            - "input_ids": List[int] of length text_seq_len, varies across samples
            - "labels": List[int] of length text_seq_len, varies across samples
            - "encoder_input": Dict[str, List[torch.Tensor]]
                - "images": List[torch.Tensor], each with shape (n_tiles, c, h, w)
                - "aspect_ratio": List[torch.Tensor], each with shape (2, ) to indicate h_ratio, w_ratio

        Shape notation:
            - c = channel dim
            - h = height dim
            - w = weight dim

        Note:
            For each element in the batch, ``len(images) == len(aspect_ratio)``.

        This collator does the following:
            (1) Pad text sequences to the longest sequence length in the batch
            (2) Pad image tensors in the tile dimension with zeros to the largest number
                of tiles in the batch (or ``self.pad_max_tiles``)
            (3) Add empty images of zeros to samples up to the max number of images in
                the batch (or ``self.pad_max_images``)
            (4) Pad aspect ratios with (1,1) for all added padding images

        Args:
            batch (List[Dict[str, Any]]): A list of sample dicts containing input_ids,
                labels, and encoder_input. Padding behavior is controlled by the
                ``padding_idx``, ``ignore_idx``, ``pad_max_tiles``, and
                ``pad_max_images`` dataclass fields.

        Returns:
            Dict[str, Tensor]: Collated tokens, labels, images, aspect_ratio tensors.
                - input_ids: Tensor of shape (bsz, max_seq_len)
                - labels: Tensor of shape (bsz, max_seq_len)
                - images: Tensor of shape (bsz, max_num_images, max_num_tiles, c, h, w)
                - aspect_ratio: Tensor of shape (bsz, max_num_images, 2)

        Raises:
            ValueError: If an image has more tiles than ``pad_max_tiles``, or a
                sample has more images than ``pad_max_images``.
        """
        # Collate the text portion. Convert to tensors here so this method does
        # not depend on padded_collate accepting raw int lists.
        text_only = [
            {
                "input_ids": torch.as_tensor(sample["input_ids"]),
                "labels": torch.as_tensor(sample["labels"]),
            }
            for sample in batch
        ]
        collated_text = padded_collate(text_only, self.padding_idx, self.ignore_idx)

        # Largest tile count across every image of every sample in the batch.
        largest_n_tiles = max(
            image.shape[0]
            for sample in batch
            for image in sample["encoder_input"]["images"]
        )
        if self.pad_max_tiles is None:
            max_num_tiles = largest_n_tiles
        else:
            if largest_n_tiles > self.pad_max_tiles:
                # Fail loudly: a negative pad below would silently crop tiles.
                raise ValueError(
                    f"Image with {largest_n_tiles} tiles exceeds pad_max_tiles "
                    f"({self.pad_max_tiles})."
                )
            max_num_tiles = self.pad_max_tiles

        batch_images = []
        batch_aspect_ratios = []
        for sample in batch:
            sample_images = []
            for image in sample["encoder_input"]["images"]:
                n_tiles = image.shape[0]
                padding_tiles = max_num_tiles - n_tiles
                # image is (n_tiles, c, h, w); only the leading tile dim is
                # padded, so the pads for the last three dims are all zero.
                padded_image = F.pad(
                    image, (0, 0, 0, 0, 0, 0, 0, padding_tiles), value=0
                )
                sample_images.append(padded_image)
            # (n_images, max_num_tiles, c, h, w)
            batch_images.append(torch.stack(sample_images))
            batch_aspect_ratios.append(
                torch.stack(sample["encoder_input"]["aspect_ratio"])
            )

        # Pad the per-sample image count up to the batch maximum. Padding
        # aspect ratios with 1 yields (1, 1) entries for the added images.
        collated_images = pad_sequence(batch_images, batch_first=True, padding_value=0)
        collated_aspect_ratios = pad_sequence(
            batch_aspect_ratios, batch_first=True, padding_value=1
        )

        # Optionally pad the image-count dim further to a fixed size; both
        # tensors are padded so len(images) == len(aspect_ratio) still holds.
        if self.pad_max_images is not None:
            num_images = collated_images.shape[1]
            if num_images > self.pad_max_images:
                raise ValueError(
                    f"Sample with {num_images} images exceeds pad_max_images "
                    f"({self.pad_max_images})."
                )
            extra_images = self.pad_max_images - num_images
            if extra_images > 0:
                # images: (bsz, n_img, n_tiles, c, h, w) -> pad dim 1.
                collated_images = F.pad(
                    collated_images,
                    (0, 0, 0, 0, 0, 0, 0, 0, 0, extra_images),
                    value=0,
                )
                # aspect_ratio: (bsz, n_img, 2) -> pad dim 1 with 1s.
                collated_aspect_ratios = F.pad(
                    collated_aspect_ratios, (0, 0, 0, extra_images), value=1
                )

        return {
            "input_ids": collated_text["input_ids"],
            "labels": collated_text["labels"],
            "encoder_input": {
                "images": collated_images,
                "aspect_ratio": collated_aspect_ratios,
            },
        }
|
|