|
|
from typing import Dict, Sequence |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
from xtuner.parallel.sequence import (get_sequence_parallel_world_size, |
|
|
pad_for_sequence_parallel) |
|
|
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX |
|
|
from einops import rearrange |
|
|
|
|
|
# Sentinel value stored in ``vision_patch_indices`` for tokens that are not
# vision patches; st_collate_fn also uses it as the padding value for those
# indices and skips it when offsetting per-sample patch indices.
NON_VISION_TOKEN = -1
|
|
|
|
|
def generate_mm_pos_ids_singleit(input_ids, vpatch_id, h, w):
    """Build 3-axis position ids for a single sample holding at most one image.

    Non-vision tokens get the same running position on all three axes. The
    ``h * w`` vision-patch tokens are assumed to be contiguous and to start at
    the first occurrence of ``vpatch_id``; they get ``(0, row, col) + start``
    so the patches are laid out as an ``h x w`` grid. Tokens after the patch
    block resume from one past the largest position used by the patches.

    Args:
        input_ids: Token ids of one sample (a list of ints, or anything
            ``torch.as_tensor`` accepts).
        vpatch_id: Token id that marks vision-patch positions.
        h: Number of patch rows.
        w: Number of patch columns. If ``h * w == 0`` the sample is treated
            as text only.

    Returns:
        int64 tensor of shape ``(3, len(input_ids))``.

    Raises:
        ValueError: If ``h * w > 0`` but ``vpatch_id`` does not occur in
            ``input_ids``, or there are fewer than ``h * w`` tokens from the
            first vision token to the end of the sequence.
    """
    seq_len = len(input_ids)
    if h * w == 0:
        # Text only: every axis is simply 0 .. seq_len - 1.
        return torch.arange(seq_len).unsqueeze(0).repeat(3, 1)

    # Build the ids directly as int64. Going through ``torch.Tensor``
    # (float32) silently rounds ids >= 2**24, which can make the vision token
    # undetectable or match the wrong token.
    ids = torch.as_tensor(input_ids, dtype=torch.long)
    hits = torch.nonzero(ids == vpatch_id, as_tuple=True)[0]
    if hits.numel() == 0:
        raise ValueError(
            f'vision patch token id {vpatch_id} not found in input_ids '
            f'although h * w = {h * w} > 0')
    start = int(hits[0])

    num_patches = h * w
    num_tail = seq_len - start - num_patches
    if num_tail < 0:
        raise ValueError(
            f'expected {num_patches} vision tokens starting at position '
            f'{start}, but the sequence has only {seq_len} tokens')

    # Text before the image: same running position on every axis.
    head = torch.arange(start).unsqueeze(-1).repeat(1, 3)

    # Vision patches: (t=0, row, col) in row-major order, shifted by start.
    grid_h, grid_w = torch.meshgrid(
        torch.arange(h), torch.arange(w), indexing='ij')
    vision = torch.stack(
        [torch.zeros_like(grid_h), grid_h, grid_w], dim=-1).reshape(-1, 3)
    vision = vision + start

    # Text after the image continues from one past the largest vision id.
    tail = torch.arange(num_tail).unsqueeze(-1).repeat(1, 3) + vision.max() + 1

    position_id = torch.cat([head, vision, tail], dim=0)
    assert position_id.size(0) == seq_len
    # (seq_len, 3) -> (3, seq_len)
    return position_id.t().long()
|
|
|
|
|
def st_collate_fn(instances: Sequence[Dict],
                  pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                  return_hf_format: bool = False,
                  use_varlen_attn: bool = False):
    """Collate multimodal samples into one batch.

    Every instance must provide ``input_ids``, ``labels``,
    ``vision_patch_indices``, ``patch_nums_per_images`` and
    ``vision_start_end``. ``vision_patches`` and ``masks`` are optional.
    With ``use_varlen_attn`` each instance must also provide
    ``cumulative_len`` and ``position_ids``.

    Args:
        instances: Per-sample dicts as described above.
        pad_index: Padding value for ``input_ids``.
        return_hf_format: If True return the data dict itself, otherwise wrap
            it as ``{'data': data_dict, 'data_samples': None}``.
        use_varlen_attn: Use packed (variable-length) attention. Requires a
            batch of exactly one text-only instance.

    Returns:
        The batch dict. In the non-varlen path it contains a boolean
        ``attention_mask`` of shape ``(batch, 1, L, L)`` and ``position_ids``
        of shape ``(3, batch, L)``; in the varlen path it contains
        ``cumulative_len`` and ``max_seqlen`` instead of ``attention_mask``.
    """
    seq_parallel_world_size = get_sequence_parallel_world_size()

    # Token id marking vision-patch positions; read from the first instance.
    vision_patch_idx = instances[0].get('vision_patch_idx')

    has_image = any(inst.get('vision_patches') is not None for inst in instances)
    has_mask = any(inst.get('masks') is not None for inst in instances)

    if use_varlen_attn:
        assert len(instances) == 1, (
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        # Parenthesised so the full message belongs to the assert (a split
        # string on the next line would be a separate no-op statement).
        assert not has_image, (
            'Currently, it is not configured to '
            'accommodate the use of varlen Attention in multimodal training')

    input_ids, labels, vision_patch_indices = [], [], []
    patch_nums_per_images, vision_start_end = [], []
    position_ids, cumulative_len = [], []
    vision_patches = [] if has_image else None
    masks = [] if has_mask else None

    # Number of vision tokens in the samples already processed. Each sample
    # numbers its own patches; shifting by this running count makes the
    # indices address rows of the batch-level ``torch.cat(vision_patches)``.
    # NOTE(review): assumes each sample's non-sentinel indices are
    # contiguous, one per patch — confirm against the dataset.
    vision_offset = 0
    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        patch_nums_per_images.append(example['patch_nums_per_images'])
        vision_start_end.append(example['vision_start_end'])

        sample_vpi = torch.LongTensor(example['vision_patch_indices'])
        is_vision = sample_vpi != NON_VISION_TOKEN
        sample_vpi[is_vision] += vision_offset
        # Advance by this sample's vision-token count. Using the running max
        # (as before) was one short for 0-based indices and reset to 0 after
        # a sample without any vision tokens, making indices collide.
        vision_offset += int(is_vision.sum())
        vision_patch_indices.append(sample_vpi)

        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))

        # Skip instances whose patches are missing or None; appending None
        # would make the later torch.cat fail.
        if has_image and example.get('vision_patches') is not None:
            vision_patches.append(example['vision_patches'])
        if has_mask:
            masks.append(example.get('masks'))

    ori_length = [len(ids) for ids in input_ids]
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
        vision_patch_indices = pad_sequence(
            vision_patch_indices, batch_first=True,
            padding_value=NON_VISION_TOKEN)
    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)
        vision_patch_indices = torch.stack(vision_patch_indices)

    if use_varlen_attn:
        assert input_ids.size(1) % seq_parallel_world_size == 0
        attention_mask = None
        position_ids = torch.stack(position_ids, dim=0)
    else:
        batch_size, seq_len = input_ids.shape
        # Per-sample mask from create_single_prefix_mask; rows/columns beyond
        # a sample's original length (batch padding) stay False.
        attention_mask = torch.zeros(
            batch_size, 1, seq_len, seq_len, dtype=torch.bool)
        for i, length in enumerate(ori_length):
            attention_mask[i, 0, :length, :length] = create_single_prefix_mask(
                vision_start_end[i], length)

        # Each call returns (3, seq_len); stacking on dim=1 gives
        # (3, batch, seq_len).
        position_ids = torch.stack([
            generate_mm_pos_ids_singleit(
                row.tolist(), vision_patch_idx, nums[0], nums[1])
            for row, nums in zip(input_ids, patch_nums_per_images)
        ], dim=1)

    if seq_parallel_world_size > 1:
        input_ids = pad_for_sequence_parallel(input_ids, pad_index)
        labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
        position_ids = pad_for_sequence_parallel(position_ids, 0)
        # Keep per-token vision indices aligned with the padded input_ids.
        vision_patch_indices = pad_for_sequence_parallel(
            vision_patch_indices, NON_VISION_TOKEN)
        if attention_mask is not None:
            # The mask is (batch, 1, L, L): grow both the query and key axes
            # so it stays square and matches the padded length. New entries
            # are False, like the batch-padding rows above.
            padded_len = input_ids.size(-1)
            old_len = attention_mask.size(-1)
            if padded_len != old_len:
                grown = torch.zeros(
                    attention_mask.size(0), 1, padded_len, padded_len,
                    dtype=attention_mask.dtype)
                grown[..., :old_len, :old_len] = attention_mask
                attention_mask = grown

    if has_image:
        vision_patches = (
            torch.cat(vision_patches, dim=0) if vision_patches else None)

    if use_varlen_attn:
        max_seqlen = (
            cumulative_len[0][1:] - cumulative_len[0][:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'position_ids': position_ids,
            'labels': labels,
            'max_seqlen': max_seqlen,
            'vision_patch_indices': vision_patch_indices,
            'masks': masks,
            'vision_patches': vision_patches,
            'patch_nums_per_images': patch_nums_per_images
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'labels': labels,
            'vision_patch_indices': vision_patch_indices,
            'masks': masks,
            'vision_patches': vision_patches,
            'patch_nums_per_images': patch_nums_per_images
        }

    if return_hf_format:
        return data_dict
    return {'data': data_dict, 'data_samples': None}
|
|
|
|
|
def create_single_prefix_mask(vision_start_end, max_len):
    """Build a ``(max_len, max_len)`` boolean attention mask for one sample.

    The base mask is causal (lower-triangular). When ``vision_start_end`` is
    given, the square block over positions ``vision_start_end[0] - 1``
    through ``vision_start_end[1]`` (inclusive) is also fully visible, so
    tokens inside that span attend to each other in both directions.

    Args:
        vision_start_end: ``None`` for a text-only sample, otherwise an
            indexable pair ``(start, end)`` of positions (presumably the
            first and last vision-patch positions — confirm against the
            dataset).
        max_len: Sequence length of the sample.

    Returns:
        Bool tensor of shape ``(max_len, max_len)``; True marks allowed
        attention. The dtype is bool in both branches (the text-only branch
        used to return float).
    """
    mask = torch.tril(torch.ones(max_len, max_len)).bool()
    if vision_start_end is not None:
        # Clamp at 0: with start == 0, ``start - 1`` would be -1, which slice
        # syntax reads as "the last row" and silently drops the whole block.
        lo = max(int(vision_start_end[0]) - 1, 0)
        hi = int(vision_start_end[1]) + 1
        mask[lo:hi, lo:hi] = True
    return mask