File size: 5,261 Bytes
002bd9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import logging
from ..models.sam.processing_sam import SamProcessor
from transformers import PreTrainedTokenizer
from typing import Optional, Union, List, Dict, Any
import torch
import pycocotools.mask
from collections import defaultdict
from .transforms import REGION_KEYS


# Module-level logger used by the collators in this file to report skipped samples
# and region-count truncation.
logger = logging.getLogger(__name__)


class SamCaptionerDataCollator:
    """Collate region-captioning samples into a padded, region-aligned batch.

    Each sample is a dict whose ``"input_ids"`` / ``"attention_mask"`` entries hold one
    token sequence per region. All sequences of the batch are padded together with
    ``tokenizer.pad``, a ``"labels"`` tensor is derived via :meth:`prepare_labels`, and
    the token tensors are regrouped per sample so their leading dims are
    ``(batch, num_regions)``. When samples carry different numbers of regions, every
    sample is truncated to the smallest region count in the batch.
    """

    # Label value written at padded positions so they are ignored in loss computation.
    label_pad_token_id: int = -100
    # Per-region token fields that are padded by the tokenizer instead of stacked directly.
    _token_keys = ("input_ids", "attention_mask")

    def __init__(self, tokenizer: PreTrainedTokenizer):
        # Only ``tokenizer.pad`` is used by this collator.
        self.tokenizer = tokenizer

    def __call__(self, batch: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Collate ``batch`` into a dict of batched tensors / lists.

        Args:
            batch: list of sample dicts. Each needs ``"input_ids"`` (one token sequence
                per region, or ``None`` for an unusable sample) and ``"attention_mask"``
                with the same layout. All other keys are batched as well (see
                ``_collate_other_fields``).

        Returns:
            ``None`` if no sample carries at least one region. Otherwise a dict with the
            padded ``input_ids`` / ``attention_mask`` / ``labels`` (leading dims
            ``(batch, num_regions)``) plus every other field of the samples; on a key
            collision the sample field wins over the tokenizer output.

        The sample dicts in ``batch`` are not modified.
        """
        batch = self._drop_empty_samples(batch)
        if not batch:
            return None

        # NOTE(xiaoke) dynamic padding
        num_regions_per_sample = [len(sample["input_ids"]) for sample in batch]
        num_minimum_regions_per_sample = min(num_regions_per_sample)

        # Samples may carry different numbers of regions; to stack them along a region
        # dimension every sample is truncated to the smallest region count.
        is_batch_of_regions = all(x == num_minimum_regions_per_sample for x in num_regions_per_sample)
        if not is_batch_of_regions:
            logger.warning(
                f"is_batch_of_regions is False due to num_minimum_regions_per_sample {num_minimum_regions_per_sample} < num_regions_per_sample {num_regions_per_sample}. "
                "Thus we chunk the regions with the minimum number of regions in the batch."
            )

        encoding_tokenizer = self._collate_tokens(batch, num_regions_per_sample, num_minimum_regions_per_sample)
        encoding_else = self._collate_other_fields(batch, num_minimum_regions_per_sample)

        return {
            **encoding_tokenizer,
            **encoding_else,
        }

    @staticmethod
    def _drop_empty_samples(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return the samples whose ``"input_ids"`` is not None and holds at least one region."""
        # A zero-region sample would otherwise force every sample in the batch to be
        # truncated to zero regions, so it is dropped just like a ``None`` sample.
        kept = [
            sample
            for sample in batch
            if sample["input_ids"] is not None and len(sample["input_ids"]) > 0
        ]
        if not kept:
            logger.error("batch is empty, skip this batch of data.")
        elif len(kept) < len(batch):
            num_skip = len(batch) - len(kept)
            logger.warning(f"batch is not empty, but some samples are empty, skip {num_skip} samples.")
        return kept

    def _collate_tokens(
        self,
        batch: List[Dict[str, Any]],
        num_regions_per_sample: List[int],
        num_regions: int,
    ):
        """Pad all region token sequences together, add labels, and regroup them per sample.

        Every returned tensor has leading dims ``(len(batch), num_regions)``.
        """
        # Flatten the per-sample region lists so the tokenizer pads every sequence of the
        # batch to one common length. The sequences are expected to be already truncated
        # to ``model_max_length`` upstream.
        flat_input_ids = [ids for sample in batch for ids in sample["input_ids"]]
        flat_attention_mask = [mask for sample in batch for mask in sample["attention_mask"]]

        encoding_tokenizer = self.tokenizer.pad(
            dict(input_ids=flat_input_ids, attention_mask=flat_attention_mask),
            padding=True,
            return_tensors="pt",
        )
        # add labels, pad with -100 to ignore in loss computation
        encoding_tokenizer["labels"] = self.prepare_labels(encoding_tokenizer)

        # Split the flat region dimension back into per-sample chunks, keep the first
        # ``num_regions`` regions of each, and stack into a new batch dimension.
        for k, v in list(encoding_tokenizer.items()):
            chunks = v.split(num_regions_per_sample)
            encoding_tokenizer[k] = torch.stack([chunk[:num_regions] for chunk in chunks])
        return encoding_tokenizer

    def _collate_other_fields(self, batch: List[Dict[str, Any]], num_regions: int) -> Dict[str, Any]:
        """Batch every non-token field (e.g. ``input_boxes``, ``metadata_*``), keyed by ``batch[0]``.

        - ``None`` in ``batch[0]`` -> the field is ``None`` (not batched).
        - tensors -> stacked; keys in ``REGION_KEYS`` are first truncated to ``num_regions``.
        - anything else -> a plain list with one entry per sample.
        """
        encoding_else = {}
        for k, v in batch[0].items():
            if k in self._token_keys:
                # Already padded and regrouped by ``_collate_tokens``.
                continue
            if v is None:
                encoding_else[k] = None
            elif isinstance(v, torch.Tensor):
                if k in REGION_KEYS:
                    # Region-level tensors are truncated to the common region count so they
                    # stay aligned with the token tensors.
                    # NOTE(xiaoke): we make sure that eval_batch_size=1
                    encoding_else[k] = torch.stack([sample[k][:num_regions] for sample in batch])
                else:
                    encoding_else[k] = torch.stack([sample[k] for sample in batch])
            else:
                encoding_else[k] = [sample[k] for sample in batch]
        return encoding_else

    def prepare_labels(self, encoding_tokenizer):
        """Return a copy of ``input_ids`` with every position whose ``attention_mask`` is 0
        replaced by ``label_pad_token_id``."""
        label_mask = encoding_tokenizer["attention_mask"].bool()
        labels = encoding_tokenizer["input_ids"].clone()
        labels.masked_fill_(~label_mask, self.label_pad_token_id)
        return labels


class SCADataCollator(SamCaptionerDataCollator):
    """Collator whose labels are shifted right by one position.

    No ``<BOS>`` token is added in either training or inference, so each label row is
    prefixed with a single ``label_pad_token_id`` (ignored) position. As a result the
    labels are one token wider than ``input_ids``.
    """

    def prepare_labels(self, encoding_tokenizer):
        """Return the base labels with one ignored position prepended on the last dim.

        Widening the labels by one would push their length past
        ``tokenizer.model_max_length`` and make the trainer's eval-loss computation fail
        on a last-dim mismatch when padding across batches/regions. To avoid that, the
        data transform trims ``input_ids`` and ``attention_mask`` by one token; see
        ``src/data/transforms/base_transforms.py:SCADataTransform:process_tokens``.
        """
        base_labels = super().prepare_labels(encoding_tokenizer)
        # A single ignored column in front of the sequence dim == shift right by one.
        ignored_prefix = base_labels.new_full((*base_labels.shape[:-1], 1), self.label_pad_token_id)
        return torch.cat((ignored_prefix, base_labels), dim=-1)