# Repository path: viewtoken-harmon-demo/src/datasets/collate_functions.py
# Page metadata: XinxuanLu, "Initial demo", commit becf13a (verified)
import torch
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
from typing import Dict, Sequence
from torch.nn.utils.rnn import pad_sequence
from functools import partial
def collate_func_gen(instances: Sequence[Dict],
                     pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate generation samples into a single right-padded batch.

    Every instance must provide ``'pixel_values'`` (stacked with
    ``torch.stack``) and ``'input_ids'`` (padded with ``pad_sequence``).
    The instances are only read, never modified, so the same sample dicts
    can be collated more than once.

    Args:
        instances: The samples that make up the batch.
        pad_index: Value written into the padded positions of ``input_ids``.

    Returns:
        A dict whose ``'data'`` entry holds ``pixel_values``, ``input_ids``
        and a boolean ``attention_mask`` (``True`` on the original tokens,
        ``False`` on padding), and whose ``'data_samples'`` entry is ``None``.
    """
    pixel_values, sequences, input_lengths = [], [], []
    for example in instances:
        # Plain lookups instead of ``pop`` keep the caller's dicts intact.
        pixel_values.append(example['pixel_values'])
        sequences.append(example['input_ids'])
        input_lengths.append(len(example['input_ids']))

    input_ids = pad_sequence(sequences, batch_first=True,
                             padding_value=pad_index)

    # Padding is appended on the right, so the first ``length`` positions of
    # each row are the real tokens.
    attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for row, length in enumerate(input_lengths):
        attention_mask[row, :length] = True

    data_dict = dict(pixel_values=torch.stack(pixel_values),
                     input_ids=input_ids,
                     attention_mask=attention_mask)
    return {'data': data_dict, 'data_samples': None}
def collate_func_und(instances, pad_index=DEFAULT_PAD_TOKEN_INDEX):
    """Collate understanding samples into a right-padded batch.

    ``'input_ids'`` and ``'labels'`` of every instance are converted with
    ``torch.LongTensor``.  ``'pixel_values'`` is optional: only the samples
    that carry it contribute to the stacked image tensor.

    Args:
        instances: The samples that make up the batch.
        pad_index: Value used to pad ``input_ids``; ``labels`` are padded
            with ``IGNORE_INDEX``.

    Returns:
        A dict whose ``'data'`` entry holds ``input_ids``, ``attention_mask``
        (boolean, ``True`` on the original tokens), ``labels`` and
        ``pixel_values`` (``None`` when no sample provided any), and whose
        ``'data_samples'`` entry is ``None``.
    """
    token_seqs, label_seqs, images = [], [], []
    for item in instances:
        token_seqs.append(torch.LongTensor(item['input_ids']))
        label_seqs.append(torch.LongTensor(item['labels']))
        if 'pixel_values' in item:
            images.append(item['pixel_values'])

    lengths = [len(seq) for seq in token_seqs]

    if len(instances) <= 1:
        # A lone sample needs no padding, only a leading batch dimension.
        input_ids = torch.stack(token_seqs)
        labels = torch.stack(label_seqs)
    else:
        # Right padding.
        input_ids = pad_sequence(token_seqs, batch_first=True,
                                 padding_value=pad_index)
        labels = pad_sequence(label_seqs, batch_first=True,
                              padding_value=IGNORE_INDEX)

    # Mark the un-padded prefix of every row.
    attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for row, n_tokens in enumerate(lengths):
        attention_mask[row, :n_tokens] = True

    pixel_values = torch.stack(images) if images else None

    # Disabled outputs kept for reference: 'raw_conversations' (raw
    # conversation data) and 'conversation_text' (formatted conversation
    # text).
    data_dict = dict(input_ids=input_ids,
                     attention_mask=attention_mask,
                     labels=labels,
                     pixel_values=pixel_values)
    return {'data': data_dict, 'data_samples': None}
def collate_func_viewpoint2image(instances: Sequence[Dict],
                                 pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for the viewpoint2image task.

    Every instance must provide ``'pixel_values'``, ``'input_ids'``,
    ``'viewpoint_params'`` and ``'viewpoint_valid_mask'``.  The token
    sequences are right-padded; all other fields are stacked with
    ``torch.stack``.  The instances are only read, never modified, so the
    same sample dicts can be collated more than once.

    Args:
        instances: The samples that make up the batch.
        pad_index: Value written into the padded positions of ``input_ids``.

    Returns:
        A dict whose ``'data'`` entry holds ``pixel_values``, ``input_ids``,
        ``attention_mask`` (boolean, ``True`` on the original tokens),
        ``viewpoint_params`` and ``viewpoint_valid_mask``, and whose
        ``'data_samples'`` entry is ``None``.
    """
    pixel_values, sequences, input_lengths = [], [], []
    viewpoint_params, viewpoint_valid_masks = [], []
    for example in instances:
        # Plain lookups instead of ``pop`` keep the caller's dicts intact.
        pixel_values.append(example['pixel_values'])
        sequences.append(example['input_ids'])
        input_lengths.append(len(example['input_ids']))
        viewpoint_params.append(example['viewpoint_params'])
        viewpoint_valid_masks.append(example['viewpoint_valid_mask'])

    input_ids = pad_sequence(sequences, batch_first=True,
                             padding_value=pad_index)

    # Padding is appended on the right, so the first ``length`` positions of
    # each row are the real tokens.
    attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for row, length in enumerate(input_lengths):
        attention_mask[row, :length] = True

    data_dict = dict(
        pixel_values=torch.stack(pixel_values),
        input_ids=input_ids,
        attention_mask=attention_mask,
        viewpoint_params=torch.stack(viewpoint_params),
        viewpoint_valid_mask=torch.stack(viewpoint_valid_masks),
    )
    return {'data': data_dict, 'data_samples': None}
def collate_func_image2viewpoint(instances: Sequence[Dict],
                                 pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for the image2viewpoint task.

    ``'input_ids'`` and ``'labels'`` of every instance are converted with
    ``torch.LongTensor`` and right-padded; ``'pixel_values'``,
    ``'viewpoint_params'`` and ``'viewpoint_valid_mask'`` are stacked with
    ``torch.stack``.

    Args:
        instances: The samples that make up the batch.
        pad_index: Value used to pad ``input_ids``; ``labels`` are padded
            with ``IGNORE_INDEX``.

    Returns:
        A dict whose ``'data'`` entry holds ``input_ids``, ``attention_mask``
        (boolean, ``True`` on the original tokens), ``labels``,
        ``pixel_values``, ``viewpoint_params`` and ``viewpoint_valid_mask``,
        and whose ``'data_samples'`` entry is ``None``.
    """
    token_seqs, label_seqs, images = [], [], []
    params, valid_masks = [], []
    for item in instances:
        token_seqs.append(torch.LongTensor(item['input_ids']))
        label_seqs.append(torch.LongTensor(item['labels']))
        images.append(item['pixel_values'])
        params.append(item['viewpoint_params'])
        valid_masks.append(item['viewpoint_valid_mask'])

    lengths = [len(seq) for seq in token_seqs]

    if len(instances) <= 1:
        # A lone sample needs no padding, only a leading batch dimension.
        input_ids = torch.stack(token_seqs)
        labels = torch.stack(label_seqs)
    else:
        # Right padding.
        input_ids = pad_sequence(token_seqs, batch_first=True,
                                 padding_value=pad_index)
        labels = pad_sequence(label_seqs, batch_first=True,
                              padding_value=IGNORE_INDEX)

    # Mark the un-padded prefix of every row.
    attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for row, n_tokens in enumerate(lengths):
        attention_mask[row, :n_tokens] = True

    data_dict = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        pixel_values=torch.stack(images),
        viewpoint_params=torch.stack(params),
        viewpoint_valid_mask=torch.stack(valid_masks),
    )
    return {'data': data_dict, 'data_samples': None}
def collate_func_relpose2image(instances: Sequence[Dict],
                               pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for the relpose2image task.

    Every instance must provide ``'src_pixel_values'``,
    ``'tgt_pixel_values'``, ``'input_ids'`` and ``'viewpoint_params'``.  The
    token sequences are right-padded; all other fields are stacked with
    ``torch.stack``.  The instances are only read, never modified, so the
    same sample dicts can be collated more than once.

    Args:
        instances: The samples that make up the batch.
        pad_index: Value written into the padded positions of ``input_ids``.

    Returns:
        A dict whose ``'data'`` entry holds ``src_pixel_values``,
        ``tgt_pixel_values``, ``input_ids``, ``attention_mask`` (boolean,
        ``True`` on the original tokens) and ``viewpoint_params``, and whose
        ``'data_samples'`` entry is ``None``.
    """
    src_pixel_values, tgt_pixel_values = [], []
    sequences, input_lengths = [], []
    viewpoint_params = []
    for example in instances:
        # Plain lookups instead of ``pop`` keep the caller's dicts intact.
        src_pixel_values.append(example['src_pixel_values'])
        tgt_pixel_values.append(example['tgt_pixel_values'])
        sequences.append(example['input_ids'])
        input_lengths.append(len(example['input_ids']))
        viewpoint_params.append(example['viewpoint_params'])

    input_ids = pad_sequence(sequences, batch_first=True,
                             padding_value=pad_index)

    # Padding is appended on the right, so the first ``length`` positions of
    # each row are the real tokens.
    attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for row, length in enumerate(input_lengths):
        attention_mask[row, :length] = True

    data_dict = dict(
        src_pixel_values=torch.stack(src_pixel_values),
        tgt_pixel_values=torch.stack(tgt_pixel_values),
        input_ids=input_ids,
        attention_mask=attention_mask,
        viewpoint_params=torch.stack(viewpoint_params),
    )
    return {'data': data_dict, 'data_samples': None}
def collate_func_compass_viewpoint2image(
        instances: Sequence[Dict],
        pad_index: int = DEFAULT_PAD_TOKEN_INDEX):
    """Collate function for the Compass multi-object viewpoint2image task.

    Handles a variable number of objects per sample by zero-padding
    ``viewpoint_params`` and ``viewpoint_valid_mask`` along their first
    dimension up to the longest entry in the batch.  Padded mask entries are
    zero, i.e. marked invalid.  The instances are only read, never modified,
    so the same sample dicts can be collated more than once.

    Args:
        instances: The samples that make up the batch.  Each must provide
            ``'pixel_values'``, ``'input_ids'``, ``'viewpoint_params'``,
            ``'viewpoint_valid_mask'`` and ``'num_objects'``.
        pad_index: Value written into the padded positions of ``input_ids``.

    Returns:
        A dict whose ``'data'`` entry holds ``pixel_values``, ``input_ids``,
        ``attention_mask`` (boolean, ``True`` on the original tokens),
        ``viewpoint_params``, ``viewpoint_valid_mask`` and ``num_objects``
        (a ``torch.long`` tensor), and whose ``'data_samples'`` entry is
        ``None``.
    """

    def _pad_to_longest(tensors):
        # Zero-pad every tensor along dim 0 up to the longest one in the list.
        target = max(t.shape[0] for t in tensors)
        padded = []
        for t in tensors:
            missing = target - t.shape[0]
            if missing > 0:
                # ``new_zeros`` keeps dtype and device, and the trailing
                # shape is copied so tensors with extra dims also work.
                filler = t.new_zeros((missing, *t.shape[1:]))
                t = torch.cat([t, filler])
            padded.append(t)
        return padded

    pixel_values, sequences, input_lengths = [], [], []
    viewpoint_params, viewpoint_valid_masks = [], []
    num_objects_list = []
    for example in instances:
        # Plain lookups instead of ``pop`` keep the caller's dicts intact.
        pixel_values.append(example['pixel_values'])
        sequences.append(example['input_ids'])
        input_lengths.append(len(example['input_ids']))
        viewpoint_params.append(example['viewpoint_params'])
        viewpoint_valid_masks.append(example['viewpoint_valid_mask'])
        num_objects_list.append(example['num_objects'])

    # Handles batches that mix samples with different object counts.
    padded_viewpoint_params = _pad_to_longest(viewpoint_params)
    padded_valid_masks = _pad_to_longest(viewpoint_valid_masks)

    # Pad input_ids on the right and mark the real tokens.
    input_ids = pad_sequence(sequences, batch_first=True,
                             padding_value=pad_index)
    attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for row, length in enumerate(input_lengths):
        attention_mask[row, :length] = True

    data_dict = dict(
        pixel_values=torch.stack(pixel_values),
        input_ids=input_ids,
        attention_mask=attention_mask,
        viewpoint_params=torch.stack(padded_viewpoint_params),
        viewpoint_valid_mask=torch.stack(padded_valid_masks),
        num_objects=torch.tensor(num_objects_list, dtype=torch.long),
    )
    return {'data': data_dict, 'data_samples': None}
class CollateConcat:
    """Route a batch to the collate function registered for its sample type.

    Args:
        collate_fns: One config mapping per key.  Each mapping must contain
            a ``'type'`` entry holding the collate callable; every other
            entry is bound to that callable as a keyword argument.  The
            mappings are copied, never modified.
        keys: Names under which the collate functions are registered, paired
            with ``collate_fns`` by position.
    """

    def __init__(self, collate_fns, keys):
        self.keys = keys
        self.collate_fns = {}
        for key, collate_fn in zip(keys, collate_fns):
            # Pop from a copy so the caller's config keeps its 'type' entry
            # and can be reused to build another instance.
            kwargs = dict(collate_fn)
            func = kwargs.pop('type')
            self.collate_fns[key] = partial(func, **kwargs)

    def __call__(self, data_samples):
        """Collate ``data_samples`` with the function matching their type.

        Empty samples are dropped first.  The ``'type'`` of the first
        remaining sample selects the collate function, which is then applied
        to all remaining samples.

        Returns:
            ``{'data': {<type>: <collated 'data'>}, 'data_samples': None}``.

        Raises:
            IndexError: If no non-empty sample is left.
        """
        data_samples = [sample for sample in data_samples if len(sample) > 0]
        key = data_samples[0]['type']
        collated = self.collate_fns[key](data_samples)['data']
        return {'data': {key: collated}, 'data_samples': None}