import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import pycocotools.mask
import torch
from transformers import PreTrainedTokenizer

from ..models.sam.processing_sam import SamProcessor
from .transforms import REGION_KEYS

logger = logging.getLogger(__name__)


class SamCaptionerDataCollator:
    """Collate region-caption samples into model-ready batches.

    Tokenized captions are flattened across samples, padded to a common
    length, and reshaped to (batch_size, num_regions, seq_len). When samples
    carry different numbers of regions, every sample is truncated to the
    minimum region count in the batch.
    """

    label_pad_token_id: int = -100

    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):
        # Drop samples whose tokenization failed (input_ids is None).
        nonempty_sample_indices = [i for i, sample in enumerate(batch) if sample["input_ids"] is not None]

        if len(nonempty_sample_indices) == 0:
            logger.error("Batch is empty; skipping this batch of data.")
            return None
        elif len(nonempty_sample_indices) < len(batch):
            num_skip = len(batch) - len(nonempty_sample_indices)
            logger.warning(f"Batch contains {num_skip} empty sample(s); dropping them.")

        batch = [batch[i] for i in nonempty_sample_indices]

        # Each sample may contribute a different number of regions; truncate
        # everything to the minimum so the batch stacks into a regular tensor.
        num_regions_per_sample = [len(sample["input_ids"]) for sample in batch]
        num_minimum_regions_per_sample = min(num_regions_per_sample)

        is_batch_of_regions = all(x == num_minimum_regions_per_sample for x in num_regions_per_sample)

        if not is_batch_of_regions:
            logger.warning(
                f"Samples have unequal region counts {num_regions_per_sample}; "
                f"truncating every sample to the minimum, {num_minimum_regions_per_sample}."
            )

        # Flatten (sample, region) captions into one list so the tokenizer
        # can pad them all to a single common length.
        flat_input_ids = []
        flat_attention_mask = []
        for sample in batch:
            flat_input_ids.extend(sample.pop("input_ids"))
            flat_attention_mask.extend(sample.pop("attention_mask"))

        encoding_tokenizer = self.tokenizer.pad(
            dict(input_ids=flat_input_ids, attention_mask=flat_attention_mask),
            padding=True,
            return_tensors="pt",
        )

        encoding_tokenizer["labels"] = self.prepare_labels(encoding_tokenizer)

        # Un-flatten: split back per sample, keep the first
        # num_minimum_regions_per_sample regions of each, and stack into
        # (batch_size, num_regions, seq_len).
        for k, v in encoding_tokenizer.items():
            encoding_tokenizer_ = v.split(num_regions_per_sample)
            encoding_tokenizer_ = [x[:num_minimum_regions_per_sample] for x in encoding_tokenizer_]
            encoding_tokenizer[k] = torch.stack(encoding_tokenizer_)

        # Collate the remaining (non-tokenizer) fields. Region-level tensors
        # are truncated to the minimum region count; image-level tensors are
        # stacked as-is; everything else is kept as a plain Python list.
        encoding_else = {}
        for k, v in batch[0].items():
            if v is None:
                encoding_else[k] = None
            elif isinstance(v, torch.Tensor):
                if k in REGION_KEYS:
                    encoding_else[k] = torch.stack(
                        [sample[k][:num_minimum_regions_per_sample] for sample in batch]
                    )
                else:
                    encoding_else[k] = torch.stack([sample[k] for sample in batch])
            else:
                encoding_else[k] = [sample[k] for sample in batch]

        return {
            **encoding_tokenizer,
            **encoding_else,
        }

    def prepare_labels(self, encoding_tokenizer):
        # Copy input_ids and overwrite padding positions with -100 so that
        # Hugging Face loss functions (ignore_index=-100) skip them.
        label_mask = encoding_tokenizer["attention_mask"].bool()
        labels = encoding_tokenizer["input_ids"].clone()
        labels.masked_fill_(~label_mask, self.label_pad_token_id)
        return labels
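
# Example: for input_ids [[5, 6, <pad>]] with attention_mask [[1, 1, 0]],
# prepare_labels returns [[5, 6, -100]]; positions set to -100 are ignored
# by the cross-entropy loss (ignore_index=-100) during training.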


class SCADataCollator(SamCaptionerDataCollator):
    def prepare_labels(self, encoding_tokenizer):
        labels = super().prepare_labels(encoding_tokenizer)
        # Prepend one ignored (-100) position, presumably so the labels line
        # up with an extra leading token in the SCA decoder's input sequence.
        return torch.nn.functional.pad(labels, (1, 0), value=self.label_pad_token_id)
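
# Usage sketch (illustrative, with assumptions): the "gpt2" checkpoint and
# the hand-built samples below are placeholders; any PreTrainedTokenizer
# with a pad token works. It shows two samples with unequal region counts
# (2 vs. 1) being truncated to the minimum and stacked.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

    def make_sample(captions):
        # One caption per region; "metadata" stands in for any extra field.
        enc = tokenizer(captions)
        return {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "metadata": {"num_regions": len(captions)},
        }

    collator = SamCaptionerDataCollator(tokenizer)
    out = collator([make_sample(["a cat", "a dog"]), make_sample(["a bird"])])
    # Regions were truncated to the minimum count (1), so each tensor has
    # shape (batch_size=2, num_regions=1, seq_len).
    print({k: tuple(v.shape) for k, v in out.items() if isinstance(v, torch.Tensor)})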