# Spaces: Running on Zero
# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.

import warnings

import torch
from torch.nn.utils.rnn import pad_sequence

from videomind.constants import IGNORE_INDEX
class HybridDataCollator(object):
    """Collate tokenized samples into padded, model-ready batch tensors.

    Pads ``input_ids`` with the tokenizer's pad token and ``labels`` with
    ``IGNORE_INDEX`` (so the loss ignores padding positions), truncates both
    to the tokenizer's ``model_max_length``, builds an attention mask, and
    passes through any optional vision / grounding fields found in the batch.
    """

    def __init__(self, tokenizer):
        # Tokenizer must expose `pad_token_id` and `model_max_length`.
        self.tokenizer = tokenizer

    def __call__(self, batch):
        """Build a batch dict from a list of per-sample dicts.

        Args:
            batch: List of dicts, each containing at least 'input_ids' and
                'labels' (1-D tensors of equal length per sample), and
                optionally vision tensors ('pixel_values',
                'pixel_values_videos', 'image_grid_thw', 'video_grid_thw')
                and grounding metadata ('timestamps', 'saliency', 'pos_clip').

        Returns:
            Dict with padded 'input_ids' and 'labels', a boolean
            'attention_mask' (True where tokens are not padding), plus any
            optional keys present in the first sample.
        """
        input_ids = [d['input_ids'] for d in batch]
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        labels = [d['labels'] for d in batch]
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Each sample's input_ids and labels must be the same length, so the
        # two padded tensors must agree in shape.
        assert input_ids.size() == labels.size()

        seq_len, max_len = input_ids.size(1), self.tokenizer.model_max_length
        if seq_len > max_len:
            # FIX: report the actual model max length; the message previously
            # hard-coded the literal '2,866' regardless of the tokenizer's
            # configured limit.
            warnings.warn(f'The length of input sequence is exceeding model max length: {seq_len} > {max_len}')
            input_ids, labels = input_ids[:, :max_len], labels[:, :max_len]

        # NOTE(review): mask treats every pad_token_id position as padding —
        # assumes the pad token never appears as a real content token.
        data = dict(input_ids=input_ids, labels=labels, attention_mask=input_ids != self.tokenizer.pad_token_id)

        # Vision tensors are concatenated along dim 0 (samples may contribute
        # variable numbers of patches, so they cannot be stacked).
        for key in ('pixel_values', 'pixel_values_videos', 'image_grid_thw', 'video_grid_thw'):
            if key in batch[0]:
                data[key] = torch.cat([d[key] for d in batch])

        # Grounding metadata is kept as per-sample Python lists, not tensors.
        for key in ('timestamps', 'saliency', 'pos_clip'):
            if key in batch[0]:
                data[key] = [d[key] for d in batch]

        return data