Spaces:
Running on Zero
Running on Zero
| import torch | |
| from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX | |
| from typing import Dict, Sequence | |
| from torch.nn.utils.rnn import pad_sequence | |
| from functools import partial | |
def collate_func_gen(instances: Sequence[Dict],
                     pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate generation-task samples into a single right-padded batch.

    The sample dicts in ``instances`` are only read, never modified, so the
    same samples can safely be collated more than once.

    Args:
        instances: Samples to batch. Each must provide ``'pixel_values'``
            (stacked with ``torch.stack``, so all entries need one common
            shape) and ``'input_ids'`` (handed to ``pad_sequence`` as-is, so
            presumably already a tensor -- confirm against the dataset).
        pad_index: Value used to right-pad ``input_ids`` up to the longest
            sequence in the batch.

    Returns:
        ``{'data': data_dict, 'data_samples': None}`` where ``data_dict``
        holds the stacked ``pixel_values``, the padded ``input_ids`` and a
        boolean ``attention_mask`` that is ``True`` on the un-padded prefix
        of every row.
    """
    pixel_values = [example['pixel_values'] for example in instances]
    sequences = [example['input_ids'] for example in instances]
    # Lengths must be recorded before padding to know where real tokens end.
    input_lengths = [len(seq) for seq in sequences]

    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_index)

    attention_mask = torch.zeros_like(input_ids).bool()
    for i, length in enumerate(input_lengths):
        attention_mask[i, :length] = True

    data_dict = dict(pixel_values=torch.stack(pixel_values),
                     input_ids=input_ids,
                     attention_mask=attention_mask)
    return {'data': data_dict, 'data_samples': None}
def collate_func_und(instances, pad_index=DEFAULT_PAD_TOKEN_INDEX):
    """Collate understanding-task samples into a right-padded batch.

    Args:
        instances: Samples to batch. Each must provide ``'input_ids'`` and
            ``'labels'`` (both converted with ``torch.LongTensor``) and may
            optionally provide ``'pixel_values'``.
        pad_index: Padding value for ``input_ids``; ``labels`` are padded
            with ``IGNORE_INDEX`` instead.

    Returns:
        ``{'data': data_dict, 'data_samples': None}``. ``data_dict`` contains
        ``input_ids``, ``attention_mask`` (boolean, ``True`` on real tokens),
        ``labels`` and ``pixel_values`` (the stacked images, or ``None`` when
        no sample carried any).
    """
    token_seqs = [torch.LongTensor(item['input_ids']) for item in instances]
    label_seqs = [torch.LongTensor(item['labels']) for item in instances]
    images = [item['pixel_values'] for item in instances if 'pixel_values' in item]
    seq_lens = [len(seq) for seq in token_seqs]

    if len(instances) <= 1:
        # A lone sample needs no padding; stacking just adds the batch axis.
        input_ids = torch.stack(token_seqs)
        labels = torch.stack(label_seqs)
    else:
        # Right padding up to the longest sequence in the batch.
        input_ids = pad_sequence(
            token_seqs, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            label_seqs, batch_first=True, padding_value=IGNORE_INDEX)

    # Padding sits on the right, so the real tokens are the first n of a row.
    attention_mask = torch.zeros_like(input_ids).bool()
    for row_idx, n_tokens in enumerate(seq_lens):
        attention_mask[row_idx, :n_tokens] = True

    # Disabled optional outputs kept from the original implementation:
    # 'raw_conversations': raw_conversations_list  (raw conversation data)
    # 'conversation_text': conversation_list  (formatted conversation text)
    data_dict = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        pixel_values=torch.stack(images) if images else None,
    )
    return {'data': data_dict, 'data_samples': None}
def collate_func_viewpoint2image(instances: Sequence[Dict],
                                 pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for viewpoint2image task.

    The sample dicts in ``instances`` are only read, never modified, so the
    same samples can safely be collated more than once.

    Args:
        instances: Samples to batch. Each must provide ``'pixel_values'``,
            ``'input_ids'``, ``'viewpoint_params'`` and
            ``'viewpoint_valid_mask'``. Everything except ``'input_ids'`` is
            combined with ``torch.stack`` and therefore needs one common
            shape across the batch; ``'input_ids'`` is handed to
            ``pad_sequence`` as-is (presumably already a tensor -- confirm
            against the dataset).
        pad_index: Value used to right-pad ``input_ids``.

    Returns:
        ``{'data': data_dict, 'data_samples': None}`` where ``data_dict``
        holds ``pixel_values``, ``input_ids``, a boolean ``attention_mask``
        (``True`` on real tokens), ``viewpoint_params`` and
        ``viewpoint_valid_mask``.
    """
    pixel_values = [example['pixel_values'] for example in instances]
    sequences = [example['input_ids'] for example in instances]
    viewpoint_params = [example['viewpoint_params'] for example in instances]
    viewpoint_valid_masks = [
        example['viewpoint_valid_mask'] for example in instances
    ]
    # Lengths must be recorded before padding to know where real tokens end.
    input_lengths = [len(seq) for seq in sequences]

    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_index)

    attention_mask = torch.zeros_like(input_ids).bool()
    for i, length in enumerate(input_lengths):
        attention_mask[i, :length] = True

    data_dict = dict(
        pixel_values=torch.stack(pixel_values),
        input_ids=input_ids,
        attention_mask=attention_mask,
        viewpoint_params=torch.stack(viewpoint_params),
        viewpoint_valid_mask=torch.stack(viewpoint_valid_masks),
    )
    return {'data': data_dict, 'data_samples': None}
def collate_func_image2viewpoint(instances: Sequence[Dict],
                                 pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for image2viewpoint task.

    Args:
        instances: Samples to batch. Each must provide ``'input_ids'`` and
            ``'labels'`` (both converted with ``torch.LongTensor``) plus
            ``'pixel_values'``, ``'viewpoint_params'`` and
            ``'viewpoint_valid_mask'`` (each combined with ``torch.stack``).
        pad_index: Padding value for ``input_ids``; ``labels`` are padded
            with ``IGNORE_INDEX`` instead.

    Returns:
        ``{'data': data_dict, 'data_samples': None}`` where ``data_dict``
        holds ``input_ids``, a boolean ``attention_mask`` (``True`` on real
        tokens), ``labels``, ``pixel_values``, ``viewpoint_params`` and
        ``viewpoint_valid_mask``.
    """
    token_seqs = [torch.LongTensor(item['input_ids']) for item in instances]
    label_seqs = [torch.LongTensor(item['labels']) for item in instances]
    images = [item['pixel_values'] for item in instances]
    params = [item['viewpoint_params'] for item in instances]
    valid_masks = [item['viewpoint_valid_mask'] for item in instances]
    seq_lens = [len(seq) for seq in token_seqs]

    if len(instances) <= 1:
        # A lone sample needs no padding; stacking just adds the batch axis.
        input_ids = torch.stack(token_seqs)
        labels = torch.stack(label_seqs)
    else:
        # Right padding up to the longest sequence in the batch.
        input_ids = pad_sequence(
            token_seqs, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            label_seqs, batch_first=True, padding_value=IGNORE_INDEX)

    # Padding sits on the right, so the real tokens are the first n of a row.
    attention_mask = torch.zeros_like(input_ids).bool()
    for row_idx, n_tokens in enumerate(seq_lens):
        attention_mask[row_idx, :n_tokens] = True

    data_dict = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        pixel_values=torch.stack(images),
        viewpoint_params=torch.stack(params),
        viewpoint_valid_mask=torch.stack(valid_masks),
    )
    return {'data': data_dict, 'data_samples': None}
def collate_func_relpose2image(instances: Sequence[Dict],
                               pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for relpose2image task.

    The sample dicts in ``instances`` are only read, never modified, so the
    same samples can safely be collated more than once.

    Args:
        instances: Samples to batch. Each must provide
            ``'src_pixel_values'``, ``'tgt_pixel_values'``, ``'input_ids'``
            and ``'viewpoint_params'``. Everything except ``'input_ids'`` is
            combined with ``torch.stack`` and therefore needs one common
            shape across the batch; ``'input_ids'`` is handed to
            ``pad_sequence`` as-is (presumably already a tensor -- confirm
            against the dataset).
        pad_index: Value used to right-pad ``input_ids``.

    Returns:
        ``{'data': data_dict, 'data_samples': None}`` where ``data_dict``
        holds ``src_pixel_values``, ``tgt_pixel_values``, ``input_ids``, a
        boolean ``attention_mask`` (``True`` on real tokens) and
        ``viewpoint_params``.
    """
    src_pixel_values = [example['src_pixel_values'] for example in instances]
    tgt_pixel_values = [example['tgt_pixel_values'] for example in instances]
    sequences = [example['input_ids'] for example in instances]
    viewpoint_params = [example['viewpoint_params'] for example in instances]
    # Lengths must be recorded before padding to know where real tokens end.
    input_lengths = [len(seq) for seq in sequences]

    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_index)

    attention_mask = torch.zeros_like(input_ids).bool()
    for i, length in enumerate(input_lengths):
        attention_mask[i, :length] = True

    data_dict = dict(
        src_pixel_values=torch.stack(src_pixel_values),
        tgt_pixel_values=torch.stack(tgt_pixel_values),
        input_ids=input_ids,
        attention_mask=attention_mask,
        viewpoint_params=torch.stack(viewpoint_params),
    )
    return {'data': data_dict, 'data_samples': None}
def _pad_first_dim(tensor, target_len):
    """Zero-pad ``tensor`` along dim 0 up to ``target_len`` entries.

    The padding keeps the tensor's dtype, device and trailing dimensions.
    Tensors that already have at least ``target_len`` entries are returned
    unchanged.
    """
    missing = target_len - tensor.shape[0]
    if missing <= 0:
        return tensor
    padding = tensor.new_zeros((missing, *tensor.shape[1:]))
    return torch.cat([tensor, padding])


def collate_func_compass_viewpoint2image(instances: Sequence[Dict],
                                         pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for Compass multi-object viewpoint2image task.

    Handles variable number of objects (1 or 2) per sample by padding
    viewpoint_params to the maximum length in the batch.

    The sample dicts in ``instances`` are only read, never modified, so the
    same samples can safely be collated more than once.

    Args:
        instances: Samples to batch. Each must provide ``'pixel_values'``,
            ``'input_ids'``, ``'viewpoint_params'``,
            ``'viewpoint_valid_mask'`` and ``'num_objects'``.
            ``'viewpoint_params'`` and ``'viewpoint_valid_mask'`` must be
            tensors; they may differ in their first dimension between
            samples.
        pad_index: Value used to right-pad ``input_ids``.

    Returns:
        ``{'data': data_dict, 'data_samples': None}`` where ``data_dict``
        holds ``pixel_values``, ``input_ids``, a boolean ``attention_mask``
        (``True`` on real tokens), the zero-padded ``viewpoint_params`` and
        ``viewpoint_valid_mask``, and ``num_objects`` as a ``torch.long``
        tensor.
    """
    pixel_values = [example['pixel_values'] for example in instances]
    sequences = [example['input_ids'] for example in instances]
    viewpoint_params = [example['viewpoint_params'] for example in instances]
    viewpoint_valid_masks = [
        example['viewpoint_valid_mask'] for example in instances
    ]
    num_objects_list = [example['num_objects'] for example in instances]
    # Lengths must be recorded before padding to know where real tokens end.
    input_lengths = [len(seq) for seq in sequences]

    # Pad viewpoint_params to max length in batch (handles mixed 1-obj and
    # 2-obj samples). Mask padding is zero in the mask's own dtype, i.e.
    # "invalid" for a boolean mask.
    max_params = max(p.shape[0] for p in viewpoint_params)
    padded_viewpoint_params = [
        _pad_first_dim(params, max_params) for params in viewpoint_params
    ]
    padded_valid_masks = [
        _pad_first_dim(mask, max_params) for mask in viewpoint_valid_masks
    ]

    # Pad input_ids
    input_ids = pad_sequence(sequences, batch_first=True, padding_value=pad_index)

    attention_mask = torch.zeros_like(input_ids).bool()
    for i, length in enumerate(input_lengths):
        attention_mask[i, :length] = True

    data_dict = dict(
        pixel_values=torch.stack(pixel_values),
        input_ids=input_ids,
        attention_mask=attention_mask,
        viewpoint_params=torch.stack(padded_viewpoint_params),
        viewpoint_valid_mask=torch.stack(padded_valid_masks),
        num_objects=torch.tensor(num_objects_list, dtype=torch.long),
    )
    return {'data': data_dict, 'data_samples': None}
class CollateConcat(object):
    """Route a batch to the collate function registered for its ``'type'``.

    Args:
        collate_fns: One config mapping per key. Each must contain a
            ``'type'`` entry holding the collate callable; every other entry
            is bound to that callable as a keyword argument.
        keys: Type names, paired positionally with ``collate_fns``.
    """

    def __init__(self, collate_fns, keys):
        self.keys = keys
        self.collate_fns = {}
        for key, collate_fn in zip(keys, collate_fns):
            # Read the config instead of popping from it: removing 'type'
            # from the caller's mapping would make the same config unusable
            # for building a second instance.
            func = collate_fn['type']
            kwargs = {name: value for name, value in collate_fn.items()
                      if name != 'type'}
            self.collate_fns[key] = partial(func, **kwargs)

    def __call__(self, data_samples):
        """Collate ``data_samples`` with the function matching their type.

        Empty samples are dropped first. The collate function is chosen from
        the ``'type'`` of the first remaining sample only, so the batch is
        assumed to be homogeneous.

        Returns:
            ``{'data': {type_key: collated}, 'data_samples': None}`` where
            ``collated`` is the ``'data'`` entry of the selected collate
            function's result.

        Raises:
            IndexError: If no non-empty sample remains.
            KeyError: If the sample's type has no registered collate function.
        """
        data_samples = [sample for sample in data_samples if len(sample) > 0]
        key = data_samples[0]['type']
        data_dict = {key: self.collate_fns[key](data_samples)['data']}
        return {'data': data_dict, 'data_samples': None}