import logging
from ..models.sam.processing_sam import SamProcessor
from transformers import PreTrainedTokenizer
from typing import Optional, Union, List, Dict, Any
import torch
import pycocotools.mask
from collections import defaultdict
from .transforms import REGION_KEYS
logger = logging.getLogger(__name__)


class SamCaptionerDataCollator:
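    """Data collator for the SAM captioner.

    It drops samples whose ``input_ids`` is None, flattens the per-region token
    lists, pads them with the tokenizer, builds ``labels`` with -100 at padded
    positions, and chunks every sample to the minimum number of regions in the
    batch so that all fields can be stacked into (batch, regions, ...) tensors.
    """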
    label_pad_token_id: int = -100

    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        # NOTE: Filter out empty samples (those whose "input_ids" is None).
        nonempty_sample_indices = []
        for i, sample in enumerate(batch):
            if sample["input_ids"] is not None:
                nonempty_sample_indices.append(i)
        if len(nonempty_sample_indices) == 0:
            logger.error("batch is empty, skip this batch of data.")
            return None
        elif len(nonempty_sample_indices) < len(batch):
            num_skip = len(batch) - len(nonempty_sample_indices)
            logger.warning(f"batch is not empty, but some samples are empty, skip {num_skip} samples.")
        batch = [batch[i] for i in nonempty_sample_indices]

        # NOTE(xiaoke) dynamic padding
        # input_ids: List[List[int]]
        # attention_mask: List[List[int]]
        num_regions_per_sample = [len(sample["input_ids"]) for sample in batch]
        num_minimum_regions_per_sample = min(num_regions_per_sample)
        is_batch_of_regions = all(x == num_minimum_regions_per_sample for x in num_regions_per_sample)
        # NOTE: if num_masks_per_sample is larger than the number of regions in some samples, the samples
        # end up with different numbers of regions, so we chunk them to the minimum number of regions in the batch.
        if not is_batch_of_regions:
            logger.warning(
                f"is_batch_of_regions is False due to num_minimum_regions_per_sample {num_minimum_regions_per_sample} < num_regions_per_sample {num_regions_per_sample}. "
                "Thus we chunk the regions with the minimum number of regions in the batch."
            )

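        # e.g., num_regions_per_sample = [3, 5, 4] -> every sample keeps only its first 3 regions below.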
        flat_input_ids = []
        flat_attention_mask = []
        for sample in batch:
            # NOTE(xiaoke): pop out the input_ids and attention_mask
            flat_input_ids.extend(sample.pop("input_ids"))
            flat_attention_mask.extend(sample.pop("attention_mask"))
        # NOTE(xiaoke): pad the input_ids and attention_mask,
        # which are already truncated to `model_max_length`
        encoding_tokenizer = self.tokenizer.pad(
            dict(input_ids=flat_input_ids, attention_mask=flat_attention_mask),
            padding=True,
            return_tensors="pt",
        )
        # add labels, pad with -100 to ignore in loss computation
        encoding_tokenizer["labels"] = self.prepare_labels(encoding_tokenizer)
        for k, v in encoding_tokenizer.items():
            encoding_tokenizer_ = v.split(num_regions_per_sample)
            encoding_tokenizer_ = [x[:num_minimum_regions_per_sample] for x in encoding_tokenizer_]
            encoding_tokenizer[k] = torch.stack(encoding_tokenizer_)
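        # NOTE: at this point input_ids / attention_mask / labels all have shape
        # (batch_size, num_minimum_regions_per_sample, padded_seq_len).
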
        # process other fields, e.g., `input_boxes`, `metadata_*`, etc.
        encoding_else = {}
        for k, v in batch[0].items():
            if v is None:
                # NOTE: if the value is None, we keep it as None and do not batchify it.
                encoding_else[k] = None
            elif isinstance(v, torch.Tensor):
                if k in REGION_KEYS:
                    # NOTE(xiaoke): it is possible for the numbers of regions to differ,
                    # i.e., to be less than num_masks_per_sample during training.
                    # NOTE(xiaoke): we make sure that eval_batch_size=1
                    encoding_else[k] = torch.stack([sample[k][:num_minimum_regions_per_sample] for sample in batch])
                else:
                    encoding_else[k] = torch.stack([sample[k] for sample in batch])
            else:
                encoding_else[k] = [sample[k] for sample in batch]
        return {
            **encoding_tokenizer,
            **encoding_else,
        }

    def prepare_labels(self, encoding_tokenizer):
        label_mask = encoding_tokenizer["attention_mask"].bool()
        labels = encoding_tokenizer["input_ids"].clone()
        labels.masked_fill_(~label_mask, self.label_pad_token_id)
        return labels


class SCADataCollator(SamCaptionerDataCollator):
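    """Same collation as SamCaptionerDataCollator, but left-pads ``labels`` with one
    extra ignore-index token, because the <BOS> token is not added during training or
    inference (see ``src/data/transforms/base_transforms.py``).
    """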
    def prepare_labels(self, encoding_tokenizer):
        labels = super().prepare_labels(encoding_tokenizer)
        # XXX(xiaoke): since we do not add the <BOS> token in either the training or the inference stage,
        # we need to shift the labels to the right by one. However, this makes the labels one token longer
        # than the input_ids, and the labels can exceed tokenizer.model_max_length by 1, which breaks
        # trainer eval (mismatched last dim during cross-batch-region padding when computing the eval loss).
        # Our solution is to trim the input_ids and attention_mask by 1.
        # Check `src/data/transforms/base_transforms.py:SCADataTransform:process_tokens` for more details.
        return torch.nn.functional.pad(labels, (1, 0), value=self.label_pad_token_id)
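

# Minimal usage sketch (illustrative addition, not part of the original collators).
# It assumes a GPT-2 tokenizer as a stand-in and a hand-built two-sample batch;
# real samples come from this package's dataset transforms. Because this module
# uses relative imports, run it with `python -m <package>.<this_module>` from the
# repository root rather than as a plain script.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    collator = SamCaptionerDataCollator(tokenizer)
    # Two samples with different numbers of regions (2 vs. 1); the collator
    # chunks both down to the minimum (1 region) before stacking.
    batch = [
        {"input_ids": [[31373, 995], [50256]], "attention_mask": [[1, 1], [1]], "metadata_id": 0},
        {"input_ids": [[18435, 612, 0]], "attention_mask": [[1, 1, 1]], "metadata_id": 1},
    ]
    outputs = collator(batch)
    # input_ids / attention_mask / labels come out as (batch_size=2, min_regions=1, padded_len=3)
    print({k: getattr(v, "shape", v) for k, v in outputs.items()})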