from functools import partial
from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def _pad_with_attention_mask(sequences, pad_index):
    """Right-pad token sequences and build the matching boolean attention mask.

    Args:
        sequences: List of tensors to pad along dim 0 (expected to be 1-D
            token-id tensors).
        pad_index: Value written into the padded positions.

    Returns:
        Tuple ``(padded, attention_mask)``. ``padded`` is the
        ``pad_sequence(..., batch_first=True)`` result; ``attention_mask`` is a
        bool tensor of shape ``(batch, max_len)`` on the same device, True
        exactly at positions ``j < len(sequences[i])``.
    """
    padded = pad_sequence(sequences, batch_first=True, padding_value=pad_index)
    lengths = torch.tensor([len(seq) for seq in sequences], device=padded.device)
    positions = torch.arange(padded.shape[1], device=padded.device)
    # Broadcast compare (1, max_len) < (batch, 1) -> (batch, max_len) bool mask.
    attention_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return padded, attention_mask


def _zero_pad_first_dim(tensor, target_len):
    """Zero-pad ``tensor`` along dim 0 up to ``target_len``.

    Returns ``tensor`` unchanged when it is already at least ``target_len``
    long. The padding keeps the tensor's dtype, device and trailing dims
    (zeros are ``False`` for bool tensors).
    """
    missing = target_len - tensor.shape[0]
    if missing <= 0:
        return tensor
    padding = tensor.new_zeros((missing,) + tuple(tensor.shape[1:]))
    return torch.cat([tensor, padding])


def collate_func_gen(instances: Sequence[Dict],
                     pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate ``gen`` samples into a padded batch.

    Each sample must provide ``pixel_values`` (stackable tensors) and
    ``input_ids`` (a tensor accepted by ``pad_sequence``). The sample dicts
    are read, not modified.

    Returns:
        ``{'data': data_dict, 'data_samples': None}``. ``data_dict`` holds
        the stacked ``pixel_values``, the right-padded ``input_ids`` and a
        boolean ``attention_mask``.
    """
    input_ids, attention_mask = _pad_with_attention_mask(
        [example['input_ids'] for example in instances], pad_index)
    data_dict = dict(
        pixel_values=torch.stack(
            [example['pixel_values'] for example in instances]),
        input_ids=input_ids,
        attention_mask=attention_mask,
    )
    return {'data': data_dict, 'data_samples': None}


def collate_func_und(instances, pad_index=DEFAULT_PAD_TOKEN_INDEX):
    """Collate ``und`` samples (token ids + labels, optional images).

    ``input_ids`` and ``labels`` are converted with ``torch.LongTensor`` and
    right-padded. Ids are padded with ``pad_index``; labels are padded with
    ``IGNORE_INDEX``.

    Returns:
        ``{'data': data_dict, 'data_samples': None}``. ``data_dict`` holds
        ``input_ids``, ``attention_mask``, ``labels`` and ``pixel_values``.
        ``pixel_values`` is None when no sample carries it.
    """
    input_ids_list = [torch.LongTensor(sample['input_ids'])
                      for sample in instances]
    labels_list = [torch.LongTensor(sample['labels']) for sample in instances]
    # NOTE(review): if only some samples carry pixel_values, the stacked
    # tensor has fewer rows than the batch; confirm the model handles that.
    pixel_values_list = [sample['pixel_values'] for sample in instances
                         if 'pixel_values' in sample]

    input_ids, attention_mask = _pad_with_attention_mask(
        input_ids_list, pad_index)
    labels = pad_sequence(
        labels_list, batch_first=True, padding_value=IGNORE_INDEX)

    data_dict = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'pixel_values': (torch.stack(pixel_values_list)
                         if pixel_values_list else None),
    }
    return {'data': data_dict, 'data_samples': None}


def collate_func_viewpoint2image(instances: Sequence[Dict],
                                 pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for viewpoint2image task.

    Each sample must provide ``pixel_values``, ``input_ids``,
    ``viewpoint_params`` and ``viewpoint_valid_mask``. All tensors except
    ``input_ids`` are stacked as-is, so they must share a shape across the
    batch. The sample dicts are not modified.
    """
    input_ids, attention_mask = _pad_with_attention_mask(
        [example['input_ids'] for example in instances], pad_index)
    data_dict = dict(
        pixel_values=torch.stack(
            [example['pixel_values'] for example in instances]),
        input_ids=input_ids,
        attention_mask=attention_mask,
        viewpoint_params=torch.stack(
            [example['viewpoint_params'] for example in instances]),
        viewpoint_valid_mask=torch.stack(
            [example['viewpoint_valid_mask'] for example in instances]),
    )
    return {'data': data_dict, 'data_samples': None}


def collate_func_image2viewpoint(instances: Sequence[Dict],
                                 pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for image2viewpoint task.

    Like ``collate_func_und`` (LongTensor ids/labels, right padding, labels
    padded with ``IGNORE_INDEX``). Every sample must additionally provide
    ``pixel_values``, ``viewpoint_params`` and ``viewpoint_valid_mask``,
    which are stacked.
    """
    input_ids_list = [torch.LongTensor(sample['input_ids'])
                      for sample in instances]
    labels_list = [torch.LongTensor(sample['labels']) for sample in instances]

    input_ids, attention_mask = _pad_with_attention_mask(
        input_ids_list, pad_index)
    labels = pad_sequence(
        labels_list, batch_first=True, padding_value=IGNORE_INDEX)

    data_dict = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'pixel_values': torch.stack(
            [sample['pixel_values'] for sample in instances]),
        'viewpoint_params': torch.stack(
            [sample['viewpoint_params'] for sample in instances]),
        'viewpoint_valid_mask': torch.stack(
            [sample['viewpoint_valid_mask'] for sample in instances]),
    }
    return {'data': data_dict, 'data_samples': None}


def collate_func_relpose2image(instances: Sequence[Dict],
                               pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for relpose2image task.

    Each sample must provide ``src_pixel_values``, ``tgt_pixel_values``,
    ``input_ids`` and ``viewpoint_params``. The sample dicts are not
    modified.
    """
    input_ids, attention_mask = _pad_with_attention_mask(
        [example['input_ids'] for example in instances], pad_index)
    data_dict = dict(
        src_pixel_values=torch.stack(
            [example['src_pixel_values'] for example in instances]),
        tgt_pixel_values=torch.stack(
            [example['tgt_pixel_values'] for example in instances]),
        input_ids=input_ids,
        attention_mask=attention_mask,
        viewpoint_params=torch.stack(
            [example['viewpoint_params'] for example in instances]),
    )
    return {'data': data_dict, 'data_samples': None}


def collate_func_compass_viewpoint2image(
        instances: Sequence[Dict], pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for Compass multi-object viewpoint2image task.

    Handles a variable number of objects (1 or 2) per sample. Both
    ``viewpoint_params`` and ``viewpoint_valid_mask`` are zero-padded along
    dim 0 to the longest ``viewpoint_params`` in the batch. Padded mask
    entries are zero/False, i.e. invalid. The padding preserves dtype,
    device and any trailing dims. The sample dicts are not modified.
    """
    viewpoint_params = [example['viewpoint_params'] for example in instances]
    valid_masks = [example['viewpoint_valid_mask'] for example in instances]

    # Pad to the max length in the batch (mixed 1-obj and 2-obj samples).
    max_params = max(params.shape[0] for params in viewpoint_params)
    padded_params = [_zero_pad_first_dim(params, max_params)
                     for params in viewpoint_params]
    padded_masks = [_zero_pad_first_dim(mask, max_params)
                    for mask in valid_masks]

    input_ids, attention_mask = _pad_with_attention_mask(
        [example['input_ids'] for example in instances], pad_index)

    data_dict = dict(
        pixel_values=torch.stack(
            [example['pixel_values'] for example in instances]),
        input_ids=input_ids,
        attention_mask=attention_mask,
        viewpoint_params=torch.stack(padded_params),
        viewpoint_valid_mask=torch.stack(padded_masks),
        num_objects=torch.tensor(
            [example['num_objects'] for example in instances],
            dtype=torch.long),
    )
    return {'data': data_dict, 'data_samples': None}


class CollateConcat:
    """Route each sample to a task-specific collate function by its ``type``.

    Args:
        collate_fns: Config mappings, one per key. Each holds a ``'type'``
            entry (the collate callable) plus the keyword arguments to bind
            to it. The mappings are copied, never mutated.
        keys: Sample ``type`` values, paired positionally with
            ``collate_fns``.
    """

    def __init__(self, collate_fns, keys):
        self.keys = keys
        self.collate_fns = {}
        for key, collate_fn in zip(keys, collate_fns):
            # Copy so popping 'type' does not mutate the caller's config,
            # which would break any later reuse of the same config.
            cfg = dict(collate_fn)
            func = cfg.pop('type')
            self.collate_fns[key] = partial(func, **cfg)

    def __call__(self, data_samples):
        """Collate a batch, grouping samples by their ``'type'`` field.

        Empty samples are dropped first. Each group keeps the original
        sample order and is collated by its own function. A batch with a
        single type therefore yields exactly one entry in ``data``.

        Raises:
            ValueError: if every sample in the batch is empty.
        """
        data_samples = [sample for sample in data_samples if len(sample) > 0]
        if not data_samples:
            raise ValueError('CollateConcat received a batch with no '
                             'non-empty samples')

        grouped = {}
        for sample in data_samples:
            grouped.setdefault(sample['type'], []).append(sample)

        data_dict = {
            key: self.collate_fns[key](samples)['data']
            for key, samples in grouped.items()
        }
        return {'data': data_dict, 'data_samples': None}