from typing import Dict, Sequence import numpy as np import torch from torch.nn.utils.rnn import pad_sequence from xtuner.parallel.sequence import (get_sequence_parallel_world_size, pad_for_sequence_parallel) from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX from einops import rearrange NON_VISION_TOKEN = -1 def generate_mm_pos_ids_singleit(input_ids, vpatch_id, h, w): if h * w == 0: nt = len(input_ids) # pure text position_id = torch.arange(nt).unsqueeze(-1).repeat(1, 3) assert len(input_ids) == position_id.size(0) position_id = rearrange(position_id, "slen d -> d slen").long() return position_id input_ids_pt = torch.Tensor(input_ids).int() vpatch_pos = torch.argwhere(input_ids_pt == vpatch_id) vpatch_start_pos = vpatch_pos[0].item() nt = len(input_ids) - (h * w) + 1 # v_pos t_indices = torch.arange(1) h_indices = torch.arange(h) w_indices = torch.arange(w) v_pos_id = torch.stack(torch.meshgrid(t_indices, h_indices, w_indices, indexing='ij'), dim=0) v_pos_id = rearrange(v_pos_id, "d t h w -> (t h w) d") # [h*w, 3] v_pos_id += vpatch_start_pos position_id = torch.cat( [ torch.arange(vpatch_start_pos).unsqueeze(-1).repeat(1, 3), v_pos_id, torch.arange(nt - vpatch_start_pos - 1).unsqueeze(-1).repeat(1, 3) + v_pos_id.max() + 1, ], dim=0 ) assert len(input_ids) == position_id.size(0) position_id = rearrange(position_id, "slen d -> d slen").long() return position_id def st_collate_fn(instances: Sequence[Dict], pad_index: int = DEFAULT_PAD_TOKEN_INDEX, return_hf_format: bool = False, use_varlen_attn: bool = False): seq_parallel_world_size = get_sequence_parallel_world_size() vision_patch_idx = instances[0].get('vision_patch_idx') input_ids, labels = [], [] has_image = any(inst.get('vision_patches') is not None for inst in instances) has_mask = any(inst.get('masks') is not None for inst in instances) if use_varlen_attn: position_ids, cumulative_len = [], [] assert len(instances) == 1, ( f'If utilizing varlen attention, the batch size should be' f' set to 1, but got {len(instances)}') assert not has_image, 'Currently, it is not configured to ' 'accommodate the use of varlen Attention in multimodal training' patch_nums_per_images = [] vision_start_end = [] vision_patch_indices = [] if has_image: vision_patches = [] else: vision_patches = None if has_mask: masks = [] else: masks = None _vision_indexes_prefix = 0 for example in instances: input_ids.append(torch.LongTensor(example['input_ids'])) labels.append(torch.LongTensor(example['labels'])) patch_nums_per_images.append(example['patch_nums_per_images']) vision_start_end.append(example['vision_start_end']) # compute new multi-batch vision patch indices batch_vision_patch_indices = torch.LongTensor(example['vision_patch_indices']) batch_vision_patch_indices[batch_vision_patch_indices!=NON_VISION_TOKEN] += _vision_indexes_prefix _vision_indexes_prefix = max(torch.max(batch_vision_patch_indices), 0) vision_patch_indices.append(batch_vision_patch_indices) if use_varlen_attn: cumulative_len.append(torch.IntTensor(example['cumulative_len'])) position_ids.append(torch.LongTensor(example['position_ids'])) if has_image: if 'vision_patches' in example.keys(): vision_patches.append(example['vision_patches']) if has_mask: if 'masks' in example.keys() and example['masks'] is not None: masks.append(example['masks']) else: masks.append(None) ori_length = [len(ids) for ids in input_ids] if len(instances) > 1: input_ids = pad_sequence( input_ids, batch_first=True, padding_value=pad_index) labels = pad_sequence( labels, batch_first=True, padding_value=IGNORE_INDEX) vision_patch_indices = pad_sequence( vision_patch_indices, batch_first=True, padding_value=NON_VISION_TOKEN) else: input_ids = torch.stack(input_ids) labels = torch.stack(labels) vision_patch_indices = torch.stack(vision_patch_indices) if use_varlen_attn: assert input_ids.size(1) % seq_parallel_world_size == 0 attention_mask = None position_ids = torch.stack(position_ids, dim=0) else: # Some tokenizers have the same eos token and pad token, so input_ids # cannot be masked directly based on the pad token id. attention_mask = torch.zeros(input_ids.shape[0], 1, input_ids.shape[1], input_ids.shape[1]).bool() for i, length in enumerate(ori_length): attention_mask[i, 0, :length, :length] = create_single_prefix_mask(vision_start_end[i], length) bs, seq_len = input_ids.shape position_ids = [] for input_id, patch_nums_per_image in zip(input_ids, patch_nums_per_images): position_id = generate_mm_pos_ids_singleit( input_id.cpu().numpy().tolist(), vision_patch_idx, patch_nums_per_image[0], patch_nums_per_image[1]) position_ids.append(position_id) position_ids = torch.stack(position_ids, dim=1) if seq_parallel_world_size > 1: input_ids = pad_for_sequence_parallel(input_ids, pad_index) labels = pad_for_sequence_parallel(labels, IGNORE_INDEX) position_ids = pad_for_sequence_parallel(position_ids, 0) if attention_mask is not None: attention_mask = pad_for_sequence_parallel(attention_mask, 0) if has_image: if len(vision_patches) == 0: vision_patches = None else: vision_patches = torch.cat(vision_patches, dim=0) if use_varlen_attn: max_seqlen = ( cumulative_len[0][1:] - # noqa: W504 cumulative_len[0][:-1]).max().item() data_dict = { 'input_ids': input_ids, 'cumulative_len': cumulative_len, 'position_ids': position_ids, 'labels': labels, 'max_seqlen': max_seqlen, 'vision_patch_indices': vision_patch_indices, 'masks': masks, 'vision_patches': vision_patches, 'patch_nums_per_images': patch_nums_per_images } else: data_dict = { 'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids, 'labels': labels, 'vision_patch_indices': vision_patch_indices, 'masks': masks, 'vision_patches': vision_patches, 'patch_nums_per_images': patch_nums_per_images } if return_hf_format: return data_dict else: return {'data': data_dict, 'data_samples': None} def create_single_prefix_mask(vision_start_end, max_len): if vision_start_end is None: # pure text attn_mask = torch.tril(torch.ones(max_len, max_len)) else: attn_mask = torch.zeros(max_len, max_len) attn_mask[vision_start_end[0]-1:vision_start_end[1]+1, vision_start_end[0]-1:vision_start_end[1]+1] = 1 causal_mask = torch.tril(torch.ones(max_len, max_len)) attn_mask = attn_mask.bool() | causal_mask.bool() return attn_mask