# Copyright (c) OpenMMLab. All rights reserved.
import math
import os.path as osp
import warnings
from collections import OrderedDict
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import init_empty_weights
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
from transformers import (AddedToken, AutoConfig, CLIPImageProcessor,
                          CLIPVisionModel, GenerationConfig,
                          LlamaForCausalLM, LlamaTokenizerFast, LlavaConfig,
                          LlavaForConditionalGeneration, LlavaProcessor,
                          PreTrainedModel)
from transformers.integrations import is_deepspeed_zero3_enabled

from xtuner.registry import BUILDER
from xtuner.model.modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from xtuner.model.utils import (LoadWoInit, find_all_linear_names,
                                get_peft_model_state_dict,
                                guess_load_checkpoint,
                                make_inputs_require_grad, traverse_dict)
from xtuner.tools.utils import get_stop_criteria, is_cn_string
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IGNORE_INDEX,
                          IMAGE_TOKEN_INDEX, PROMPT_TEMPLATE)


def convert_state_dict_to_hf(state_dict, mapping):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.endswith('.inv_freq'):
            continue
        for key_to_modify, new_key in mapping.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)
        new_state_dict[key] = value
    return new_state_dict


class SingleLLaVAModelSFT(BaseModel):

    def __init__(self,
                 llm,
                 visual_encoder=None,
                 tokenizer=None,
                 freeze_llm=False,
                 freeze_visual_encoder=False,
                 visual_select_layer=-2,
                 pretrained_pth=None,
                 projector_depth=0,
                 llm_lora=None,
                 visual_encoder_lora=None,
                 use_activation_checkpointing=True,
                 max_position_embeddings=None,
                 add_cls_token=False,
                 template=None):
        super().__init__()
        if tokenizer is not None:
            self.tokenizer = tokenizer
            tokenizer_type = self.tokenizer['type']
            del self.tokenizer['type']
            self.tokenizer = tokenizer_type(**self.tokenizer)
        self.freeze_llm = freeze_llm
        self.freeze_visual_encoder = freeze_visual_encoder
        with LoadWoInit():
            if isinstance(llm, dict):
                llm = self._dispatch_lm_model_cfg(llm,
                                                  max_position_embeddings)
            self.llm = self._build_from_cfg_or_module(llm)
            if visual_encoder is not None:
                self.visual_encoder = self._build_from_cfg_or_module(
                    visual_encoder)
            else:
                self.visual_encoder = None
        self.llm.config.use_cache = False
        # dispatch_modules(self.llm)

        self.projector_depth = projector_depth
        # projector_config = ProjectorConfig(
        #     visual_hidden_size=self.visual_encoder.config.hidden_size,
        #     llm_hidden_size=self.llm.config.hidden_size,
        #     depth=self.projector_depth)
        # self.projector = ProjectorModel(projector_config).to(
        #     self.visual_encoder.dtype)
        self.projector = None

        if self.freeze_llm:
            self.llm.requires_grad_(False)
        if self.freeze_visual_encoder:
            if self.visual_encoder is not None:
                self.visual_encoder.requires_grad_(False)

        if use_activation_checkpointing:
            # For backward compatibility
            if hasattr(self.llm, 'enable_input_require_grads'):
                self.llm.enable_input_require_grads()
            else:
                self.llm.get_input_embeddings().register_forward_hook(
                    make_inputs_require_grad)
            # enable gradient (activation) checkpointing for memory efficiency
            self.gradient_checkpointing_enable()

        self.use_llm_lora = llm_lora is not None
        self.use_visual_encoder_lora = visual_encoder_lora is not None

        if self.use_llm_lora:
            self._prepare_llm_for_lora(llm_lora, use_activation_checkpointing)
        if self.use_visual_encoder_lora:
            self._prepare_visual_encoder_for_lora(
                visual_encoder_lora, use_activation_checkpointing)

        if pretrained_pth is not None:
            pretrained_state_dict = guess_load_checkpoint(pretrained_pth)
            self.load_state_dict(pretrained_state_dict, strict=False)
            print_log(f'Load pretrained weight from {pretrained_pth}',
                      'current')

        self.visual_select_layer = visual_select_layer
        self._is_init = True
        self.is_first_iter = True
        self.add_cls_token = add_cls_token
        self.template = template

    def _parse_lora_config(self, lora_config):
        if isinstance(lora_config, dict) or isinstance(
                lora_config, Config) or isinstance(lora_config, ConfigDict):
            lora_config = BUILDER.build(lora_config)
        return lora_config

    def _prepare_llm_for_lora(self,
                              lora_config,
                              use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        self.llm = prepare_model_for_kbit_training(
            self.llm, use_activation_checkpointing)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.llm)
            lora_config.target_modules = modules
        self.llm = get_peft_model(self.llm, lora_config)

    def _prepare_visual_encoder_for_lora(self,
                                         lora_config,
                                         use_activation_checkpointing=True):
        lora_config = self._parse_lora_config(lora_config)
        if lora_config.target_modules is None:
            modules = find_all_linear_names(self.visual_encoder)
            lora_config.target_modules = modules
        self.visual_encoder = get_peft_model(self.visual_encoder, lora_config)

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.llm.gradient_checkpointing_enable()
        if self.visual_encoder is not None:
            self.visual_encoder.gradient_checkpointing_enable()
        if self.projector is not None:
            self.projector.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.llm.gradient_checkpointing_disable()
        if self.visual_encoder is not None:
            self.visual_encoder.gradient_checkpointing_disable()
        if self.projector is not None:
            self.projector.gradient_checkpointing_disable()

    def init_weights(self):
        pass

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()
        # Step 1. visual_encoder
        if self.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.visual_encoder, state_dict=state_dict))
        elif not self.freeze_visual_encoder:
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'visual_encoder.' in k
            })
        # Step 2. LLM
        if self.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(self.llm, state_dict=state_dict))
        elif not self.freeze_llm:
            to_return.update(
                {k: v
                 for k, v in state_dict.items() if 'llm.' in k})
        # Step 3. Projector
        to_return.update(
            {k: v
             for k, v in state_dict.items() if 'projector.' in k})
        return to_return

    @staticmethod
    def _prepare_for_long_context_training(cfg, llm_cfg,
                                           max_position_embeddings):
        orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
        if orig_rope_scaling is None:
            orig_rope_scaling = {'factor': 1}

        orig_rope_scaling_factor = orig_rope_scaling[
            'factor'] if 'factor' in orig_rope_scaling.keys() else 1
        orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
        if orig_ctx_len:
            orig_ctx_len *= orig_rope_scaling_factor
            if max_position_embeddings > orig_ctx_len:
                scaling_factor = float(
                    math.ceil(max_position_embeddings / orig_ctx_len))
                llm_cfg.rope_scaling = {
                    'type': 'linear',
                    'factor': scaling_factor
                }

        # hardcode for internlm2
        llm_cfg.attn_implementation = 'flash_attention_2'
        cfg.config = llm_cfg
        return cfg, llm_cfg

    @staticmethod
    def _prepare_for_flash_attn(cfg, llm_cfg):
        cls_name = type(llm_cfg).__name__
        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
                             'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
                             'Starcoder2Config', 'Phi3Config')
        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig',
                               'GemmaConfig', 'MistralConfig', 'MixtralConfig',
                               'Qwen2Config', 'Qwen2MoeConfig',
                               'Starcoder2Config', 'Phi3Config')

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        if getattr(cfg, 'attn_implementation', None) is not None:
            # Flash Attention 2.0 only supports torch.float16 and
            # torch.bfloat16 dtypes
            if cfg.attn_implementation == 'flash_attention_2':
                cfg.torch_dtype = torch_dtype
        elif SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
            cfg.torch_dtype = torch_dtype
            cfg.attn_implementation = 'flash_attention_2'
        elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
            cfg.attn_implementation = 'sdpa'

        return cfg, llm_cfg

    @staticmethod
    def _prepare_for_qlora_zero3(cfg):
        if (not is_deepspeed_zero3_enabled()) or (not hasattr(
                cfg, 'quantization_config')):
            return cfg

        torch_dtype = torch.bfloat16 if (
            torch.cuda.is_available() and torch.cuda.is_bf16_supported()) \
            else torch.float16

        cfg.torch_dtype = torch_dtype
        quantization_config = cfg.quantization_config
        quantization_config.bnb_4bit_compute_dtype = torch_dtype
        quantization_config.bnb_4bit_quant_storage = torch_dtype

        return cfg

    def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
        cfg = self._prepare_for_qlora_zero3(cfg)
        pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
        llm_cfg = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True)
        cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
        if max_position_embeddings is not None:
            cfg, llm_cfg = self._prepare_for_long_context_training(
                cfg, llm_cfg, max_position_embeddings)
        return cfg

    def _build_from_cfg_or_module(self, cfg_or_mod):
        if isinstance(cfg_or_mod, nn.Module):
            return cfg_or_mod
        elif isinstance(cfg_or_mod, dict):
            traverse_dict(cfg_or_mod)
            return BUILDER.build(cfg_or_mod)
        else:
            raise NotImplementedError

    def forward(self, data, data_samples=None, mode='loss'):
        if self.is_first_iter:
            # hardcode for qlora DeepSpeed ZeRO3, put buffers and QuantState
            # to device
            # Only required in `LLaVAModel`.
            # We do not need this in `SupervisedFinetune`.
            self.to(data['input_ids'].device)
            self.is_first_iter = False

        # if 'pixel_values' in data:
        #     # no visual encoder
        #     # visual_outputs = self.visual_encoder(
        #     #     data['pixel_values'].to(self.visual_encoder.dtype),
        #     #     output_hidden_states=True)
        #     # pixel_values = self.projector(
        #     #     visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
        #     # data['pixel_values'] = pixel_values
        #     # only merge the image and text tokens.
        #     data = prepare_inputs_labels_for_multimodal_solo(
        #         llm=self.llm, tokenizer=self.tokenizer, **data,
        #         add_CLS=self.add_cls_token)
        data = prepare_inputs_labels_for_multimodal_solo(
            llm=self.llm,
            tokenizer=self.tokenizer,
            **data,
            add_CLS=self.add_cls_token)

        if mode == 'loss':
            loss = self.compute_loss(data, data_samples)
            if torch.isnan(loss['loss']):
                print('loss nan here')
            return loss
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs

    def predict(self, data, data_samples=None):
        outputs = self.llm(**data)
        logits_dict = [{'logits': logits} for logits in outputs.logits]
        return logits_dict

    def compute_loss(self, data, data_samples=None):
        outputs = self.llm(**data)
        loss_dict = {'loss': outputs.loss}
        return loss_dict

    def __getattr__(self, name: str):
        # return super().__getattr__(name)
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.llm, name)

    def to_hf(self,
              cfg,
              save_dir,
              fp32=False,
              save_pretrained_kwargs={},
              save_format='xtuner',
              **kwargs):
        if save_format == 'xtuner':
            self.to_xtuner_llava(cfg, save_dir, fp32, save_pretrained_kwargs)
        elif save_format == 'huggingface':
            self.to_huggingface_llava(cfg, save_dir, fp32,
                                      save_pretrained_kwargs)
        elif save_format == 'official':
            self.to_official_llava(cfg, save_dir, fp32,
                                   save_pretrained_kwargs)
        else:
            raise NotImplementedError

    def to_xtuner_llava(self,
                        cfg,
                        save_dir,
                        fp32=False,
                        save_pretrained_kwargs={}):
        # LLM
        self.llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            self.llm.half()
        if self.use_llm_lora:
            llm_path = osp.join(save_dir, 'llm_adapter')
            print_log(f'Saving LLM adapter to {llm_path}', 'current')
            self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
        elif not self.freeze_llm:
            llm_path = save_dir
            print_log(f'Saving LLM tokenizer to {llm_path}', 'current')
            tokenizer = BUILDER.build(cfg.tokenizer)
            tokenizer.save_pretrained(llm_path, **save_pretrained_kwargs)
            print_log(f'Saving LLM to {llm_path}', 'current')
            self.llm.save_pretrained(llm_path, **save_pretrained_kwargs)
        self.llm.config.use_cache = False

        # Visual Encoder
        if self.use_visual_encoder_lora:
            visual_encoder_path = osp.join(save_dir, 'visual_encoder_adapter')
            print_log(
                f'Saving visual_encoder adapter to {visual_encoder_path}',
                'current')
            self.visual_encoder.save_pretrained(visual_encoder_path,
                                                **save_pretrained_kwargs)
        elif not self.freeze_visual_encoder:
            visual_encoder_path = osp.join(save_dir, 'visual_encoder')
            print_log(
                'Saving visual_encoder image_processor to '
                f'{visual_encoder_path}', 'current')
            image_processor = BUILDER.build(cfg.image_processor)
            image_processor.save_pretrained(visual_encoder_path,
                                            **save_pretrained_kwargs)
            print_log(f'Saving visual_encoder to {visual_encoder_path}',
                      'current')
            self.visual_encoder.save_pretrained(visual_encoder_path,
                                                **save_pretrained_kwargs)

        # Projector
        projector_path = osp.join(save_dir, 'projector')
        print_log(f'Saving projector to {projector_path}', 'current')
        self.projector.save_pretrained(projector_path,
                                       **save_pretrained_kwargs)

    def to_huggingface_llava(self,
                             cfg,
                             save_dir,
                             fp32=False,
                             save_pretrained_kwargs={}):

        LLM_MAPPING = {
            'model': 'language_model.model',
            'lm_head': 'language_model.lm_head',
        }
        VIT_MAPPING = {
            'vision_model': 'vision_tower.vision_model',
        }
        PROJECTOR_MAPPING = {
            'model.0': 'multi_modal_projector.linear_1',
            'model.2': 'multi_modal_projector.linear_2',
        }

        assert getattr(self.llm, 'hf_quantizer', None) is None, \
            'This conversion format does not support quantized LLM.'

        # get state_dict
        llm = self.llm
        if self.use_llm_lora:
            llm = self.llm.merge_and_unload()
        llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            llm.half()

        assert isinstance(llm, LlamaForCausalLM), \
            'This conversion format only supports LlamaForCausalLM.'
        llm_state_dict = llm.state_dict()
        llm_state_dict = convert_state_dict_to_hf(llm_state_dict, LLM_MAPPING)

        need_visual_encoder = (not self.freeze_visual_encoder
                               or self.use_visual_encoder_lora)
        visual_encoder = self.visual_encoder
        if self.use_visual_encoder_lora:
            visual_encoder = self.visual_encoder.merge_and_unload()
        assert isinstance(visual_encoder, CLIPVisionModel), \
            'This conversion format only supports CLIPVisionModel.'
        if need_visual_encoder:
            visual_encoder_state_dict = visual_encoder.state_dict()
            visual_encoder_state_dict = convert_state_dict_to_hf(
                visual_encoder_state_dict, VIT_MAPPING)
        else:
            visual_encoder_state_dict = {}

        projector_state_dict = self.projector.state_dict()
        projector_state_dict = convert_state_dict_to_hf(
            projector_state_dict, PROJECTOR_MAPPING)

        state_dict = {
            **projector_state_dict,
            **llm_state_dict,
            **visual_encoder_state_dict
        }

        # init model
        text_config = llm.config
        vision_config = visual_encoder.config
        config = LlavaConfig(
            text_config=text_config,
            vision_config=vision_config,
            attn_implementation='eager')

        with init_empty_weights():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message='.*non-meta.*', category=UserWarning)
                model = LlavaForConditionalGeneration(config)
        model.load_state_dict(state_dict, strict=True, assign=True)

        # processor
        cfg.tokenizer.type = LlamaTokenizerFast.from_pretrained
        tokenizer = BUILDER.build(cfg.tokenizer)

        tokenizer.add_tokens(
            AddedToken(DEFAULT_IMAGE_TOKEN, special=True, normalized=False),
            special_tokens=True)
        tokenizer.add_special_tokens({'pad_token': '<pad>'})

        image_processor = BUILDER.build(cfg.image_processor)
        assert isinstance(image_processor, CLIPImageProcessor), \
            'This conversion format only supports CLIPImageProcessor.'

        processor = LlavaProcessor(
            tokenizer=tokenizer, image_processor=image_processor)

        # Pad to 64 for performance reasons
        pad_shape = 64

        pre_expansion_embeddings = \
            model.language_model.model.embed_tokens.weight.data
        mu = torch.mean(pre_expansion_embeddings, dim=0).float()
        n = pre_expansion_embeddings.size()[0]
        sigma = ((pre_expansion_embeddings - mu).T
                 @ (pre_expansion_embeddings - mu)) / n
        dist = torch.distributions.multivariate_normal.MultivariateNormal(
            mu, covariance_matrix=1e-5 * sigma)

        # We add an image token so we need to resize the model
        ori_vocab_size = config.text_config.vocab_size
        tokenizer_vocab_size = tokenizer.encode('<pad>')[-1]
        added_token = tokenizer_vocab_size - ori_vocab_size

        if added_token > 0:
            model.resize_token_embeddings(ori_vocab_size + added_token,
                                          pad_shape)
            model.language_model.model.embed_tokens.weight.data[
                ori_vocab_size:] = torch.stack(
                    tuple(
                        dist.sample()
                        for _ in range(model.language_model.model.embed_tokens.
                                       weight.data[ori_vocab_size:].shape[0])),
                    dim=0,
                )
            model.language_model.lm_head.weight.data[
                ori_vocab_size:] = torch.stack(
                    tuple(dist.sample()
                          for _ in range(model.language_model.lm_head.weight.
                                         data[ori_vocab_size:].shape[0])),
                    dim=0,
                )
        model.config.image_token_index = tokenizer.encode(
            DEFAULT_IMAGE_TOKEN)[-1]
        model.config.pad_token_id = tokenizer.encode('<pad>')[-1]

        # save
        print_log(f'Saving to {save_dir}', 'current')
        model.save_pretrained(save_dir, **save_pretrained_kwargs)
        processor.save_pretrained(save_dir, **save_pretrained_kwargs)

    def to_official_llava(self,
                          cfg,
                          save_dir,
                          fp32=False,
                          save_pretrained_kwargs={}):

        VIT_MAPPING = {
            'vision_model': 'model.vision_tower.vision_tower.vision_model',
        }
        PROJECTOR_MAPPING = {
            'model.0': 'model.mm_projector.0',
            'model.2': 'model.mm_projector.2',
        }

        try:
            from llava.model import LlavaConfig, LlavaLlamaForCausalLM
        except ImportError:
            raise ImportError(
                'Please install llava with '
                '`pip install git+https://github.com/haotian-liu/LLaVA.git '
                '--no-deps`.')

        assert getattr(self.llm, 'hf_quantizer', None) is None, \
            'This conversion format does not support quantized LLM.'

        # get state_dict
        llm = self.llm
        if self.use_llm_lora:
            llm = self.llm.merge_and_unload()
        llm.config.use_cache = True
        if not fp32:
            print_log('Convert LLM to float16', 'current')
            llm.half()

        assert isinstance(llm, LlamaForCausalLM), \
            'This conversion format only supports LlamaForCausalLM.'
        llm_state_dict = llm.state_dict()

        need_visual_encoder = (not self.freeze_visual_encoder
                               or self.use_visual_encoder_lora)
        visual_encoder = self.visual_encoder
        if self.use_visual_encoder_lora:
            visual_encoder = self.visual_encoder.merge_and_unload()
        assert isinstance(visual_encoder, CLIPVisionModel), \
            'This conversion format only supports CLIPVisionModel.'
        if need_visual_encoder:
            visual_encoder_state_dict = visual_encoder.state_dict()
            visual_encoder_state_dict = convert_state_dict_to_hf(
                visual_encoder_state_dict, VIT_MAPPING)
        else:
            visual_encoder_state_dict = {}

        projector_state_dict = self.projector.state_dict()
        projector_state_dict = convert_state_dict_to_hf(
            projector_state_dict, PROJECTOR_MAPPING)

        state_dict = {
            **projector_state_dict,
            **llm_state_dict,
            **visual_encoder_state_dict
        }

        # init model
        tokenizer = BUILDER.build(cfg.tokenizer)
        image_processor = BUILDER.build(cfg.image_processor)
        assert isinstance(image_processor, CLIPImageProcessor), \
            'This conversion format only supports CLIPImageProcessor.'
        llava_config_dict = llm.config.__dict__.copy()
        llava_config_dict.update(
            dict(
                image_aspect_ratio='pad',
                mm_hidden_size=visual_encoder.config.hidden_size,
                mm_projector_type=f'mlp{self.projector_depth}x_gelu',
                mm_use_im_patch_token=False,
                mm_use_im_start_end=False,
                mm_vision_select_feature='patch',
                mm_vision_select_layer=self.visual_select_layer,
                mm_vision_tower=visual_encoder.config.name_or_path,
                unfreeze_mm_vision_tower=need_visual_encoder,
                model_type='llava',
                use_cache=True,
                use_mm_proj=True))

        llava_config = LlavaConfig(**llava_config_dict)

        with init_empty_weights():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    'ignore', message='.*non-meta.*', category=UserWarning)
                model = LlavaLlamaForCausalLM(llava_config)

        model.load_state_dict(state_dict, strict=True, assign=True)

        # save
        print_log(f'Saving to {save_dir}', 'current')
        model.save_pretrained(save_dir, **save_pretrained_kwargs)
        image_processor.save_pretrained(save_dir, **save_pretrained_kwargs)
        tokenizer.save_pretrained(save_dir, **save_pretrained_kwargs)

    def preparing_for_generation(self, metainfo):
        # set stop criteria and generation configs for the model
        assert hasattr(self, 'tokenizer'), \
            'The model does not have a tokenizer!'
        self.bot_name = 'BOT'
        # template = PROMPT_TEMPLATE['mistral']
        # self.template = template
        stop_words = []
        stop_words += self.template.get('STOP_WORDS', [])
        stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=stop_words)
        self.stop_criteria = stop_criteria

        default_generation_kwargs = dict(
            # keep the max tokens
            max_new_tokens=2048,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=(self.tokenizer.pad_token_id
                          if self.tokenizer.pad_token_id is not None
                          else self.tokenizer.eos_token_id),
        )
        default_generation_kwargs.update(metainfo.get('generation_kwargs', {}))
        self.gen_config = GenerationConfig(**default_generation_kwargs)
        self.init_prediction_config = True
        return

    def predict_forward(self, pixel_values, text_prompts, **kwargs):
        # pixel_values: image tensor
        # text_prompts: question without template
        text_prompts = text_prompts.replace('\n', '').strip()
        # print("text_prompt: ", text_prompts)
        assert self.init_prediction_config, \
            'Please set prediction configs using self.preparing_for_generation()'

        # add template for text
        input_text = ''
        input_text += self.template['INSTRUCTION'].format(
            input=text_prompts, round=1, bot_name=self.bot_name)
        # input_text = '' + input_text
        cur_encode = self.tokenizer.encode(input_text)
        # chunk_encode = []
        # for idx, chunk in enumerate(input_text.split(DEFAULT_IMAGE_TOKEN)):
        #     # if idx == 0:
        #     #     cur_encode = self.tokenizer.encode(chunk)
        #     # else:
        #     #     cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
        #     cur_encode = self.tokenizer.encode(chunk)
        #     chunk_encode.append(cur_encode)
        # # assert len(chunk_encode) == 2

        # prepend the image placeholder token, then the encoded prompt
        ids = [IMAGE_TOKEN_INDEX]
        ids.extend(cur_encode)
        # for idx, cur_chunk_encode in enumerate(chunk_encode):
        #     ids.extend(cur_chunk_encode)
        #     if idx != len(chunk_encode) - 1:
        #         ids.append(IMAGE_TOKEN_INDEX)
        ids = torch.tensor(ids).cuda().unsqueeze(0)
        pixel_values = pixel_values.cuda().unsqueeze(0)
        # print(torch.max(pixel_values), ' ', torch.min(pixel_values))

        # resize so that the longer side is at most 1024 and both sides are
        # multiples of 32
        h, w = pixel_values.shape[-2:]
        if max(h, w) > 1024:
            if h > w:
                h_new = 1024
                w_new = int(w * h_new / h)
                w_new = pad_32(w_new)
            else:
                w_new = 1024
                h_new = int(h * w_new / w)
                h_new = pad_32(h_new)
        else:
            h_new = pad_32(h)
            w_new = pad_32(w)
        dtype = pixel_values.dtype
        pixel_values = F.interpolate(
            pixel_values.to(torch.float32),
            size=(h_new, w_new),
            mode='bilinear',
            align_corners=False).to(dtype)
        mm_inputs = prepare_inputs_labels_for_multimodal_solo(
            llm=self.llm,
            tokenizer=self.tokenizer,
            input_ids=ids,
            pixel_values=pixel_values)

        if 'input_ids' in mm_inputs.keys() and \
                mm_inputs['input_ids'] is not None:
            inp_length = mm_inputs['input_ids'].shape[1]
        else:
            inp_length = 0
        # print(inp_length)

        generate_output = self.llm.generate(
            **mm_inputs,
            generation_config=self.gen_config,
            streamer=None,
            bos_token_id=self.tokenizer.bos_token_id,
            stopping_criteria=self.stop_criteria,
            output_hidden_states=False,
            return_dict_in_generate=True)

        # predict = self.tokenizer.decode(
        #     generate_output.sequences[0], skip_special_tokens=False)
        # print(predict)
        # print('\n\n\n', '--------------------------------')

        predict = self.tokenizer.decode(
            generate_output.sequences[0][inp_length:],
            skip_special_tokens=True).strip()
        print(predict)
        return {'prediction': predict}


def prepare_inputs_labels_for_multimodal_solo(
        llm: PreTrainedModel,
        tokenizer=None,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        add_CLS: bool = False,
):
    # pixel_values: (1, 3, 1024, 1024)
    ori_input_ids = input_ids
    # no image data
    if pixel_values is None:
        return {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'past_key_values': past_key_values,
            'inputs_embeds': None,
            'labels': labels
        }

    # image data
    _labels = labels  # (B, N)
    _position_ids = position_ids
    _attention_mask = attention_mask

    vision_patch_indices = []
    vision_patches = []
    visual_tokens = []
    patch_size = 32
    NON_VISION_TOKEN = -1

    if isinstance(pixel_values, torch.Tensor):
        assert pixel_values.shape[0] == 1
        pixel_values = pixel_values[0]
        patches = pixel_values.unfold(1, patch_size, patch_size) \
            .unfold(2, patch_size, patch_size)
        patches = patches.permute(1, 2, 0, 3, 4).contiguous()
        # -> (N_H_PATCHES, N_W_PATCHES, C, PATCH_H, PATCH_W)
        n_rows, n_cols = patches.shape[:2]
        n_patches = n_rows * n_cols
        patches = patches.view(n_patches, -1)

        # vision-format special tokens; they must exist in the tokenizer
        # vocabulary
        img_tokens = ['<vision>']
        cur_patch_indices = [NON_VISION_TOKEN]
        for row_idx in range(n_rows):
            for col_idx in range(n_cols):
                if row_idx != 0 and col_idx == 0:
                    # when new row starts
                    img_tokens.append('<vrow_sep>')
                    cur_patch_indices.append(NON_VISION_TOKEN)
                img_tokens.append('<vpatch>')
                cur_patch_indices.append(row_idx * n_cols + col_idx)
img_tokens.append("") cur_patch_indices.append(NON_VISION_TOKEN) if add_CLS: # add cls token to align pretrain img_tokens.append("<|vis_cls|>") cur_patch_indices.append(NON_VISION_TOKEN) cur_tokens = torch.Tensor(tokenizer.convert_tokens_to_ids(img_tokens, )) assert len(cur_tokens) == len(cur_patch_indices), f"{len(cur_tokens)} != {len(cur_patch_indices)}" vision_patch_indices.append(torch.Tensor(cur_patch_indices).to(ori_input_ids)) vision_patches.append(patches.to(pixel_values.dtype)) visual_tokens.append(cur_tokens) else: for pixel_value in pixel_values: per_image_patches = pixel_value.unfold(1, patch_size, patch_size) \ .unfold(2, patch_size, patch_size) per_image_patches = per_image_patches.permute(1, 2, 0, 3, 4).contiguous() # -> (N_H_PATCHES, N_W_PATCHES, C, PATCH_H, PATCH_W) n_rows, n_cols = per_image_patches.shape[:2] n_patches = n_rows * n_cols per_image_patches = per_image_patches.view(n_patches, -1) img_tokens = [""] cur_patch_indices = [NON_VISION_TOKEN] for row_idx in range(n_rows): for col_idx in range(n_cols): if row_idx != 0 and col_idx == 0: # when new row starts img_tokens.append(f"") cur_patch_indices.append(NON_VISION_TOKEN) img_tokens.append(f"") cur_patch_indices.append(row_idx * n_cols + col_idx) # note we use img_tokens.append("") cur_patch_indices.append(NON_VISION_TOKEN) if add_CLS: # add cls token to align pretrain img_tokens.append("<|vis_cls|>") cur_patch_indices.append(NON_VISION_TOKEN) cur_tokens = torch.Tensor(tokenizer.convert_tokens_to_ids(img_tokens, )) assert len(cur_tokens) == len(cur_patch_indices), f"{len(cur_tokens)} != {len(cur_patch_indices)}" vision_patch_indices.append(torch.Tensor(cur_patch_indices).to(ori_input_ids)) vision_patches.append(per_image_patches.to(pixel_value.dtype)) visual_tokens.append(cur_tokens) # for support multi batch prefix_num = 0 for i in range(len(vision_patch_indices)): vision_patch_indices[i] = vision_patch_indices[i] + prefix_num prefix_num += len(vision_patches[i]) vision_patches = torch.cat(vision_patches, dim=0) if attention_mask is None: attention_mask = torch.ones_like(input_ids, dtype=torch.bool) else: attention_mask = attention_mask.bool() if position_ids is None: position_ids = torch.arange( 0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) if labels is None: labels = torch.full_like(input_ids, IGNORE_INDEX) # remove the padding using attention_mask -- TODO: double check input_ids = [ cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) ] labels = [ cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask) ] new_inputs_ids = [] new_vision_ids = [] new_labels = [] cur_image_idx = 0 for batch_idx, cur_input_ids in enumerate(input_ids): num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() if num_images == 0: new_inputs_ids.append(cur_input_ids) new_labels.append(labels[batch_idx]) new_vision_ids.append(cur_input_ids * 0 + NON_VISION_TOKEN) cur_image_idx += 1 continue need_replace = cur_input_ids == IMAGE_TOKEN_INDEX num_replace = need_replace.sum() image_token_indices = [-1] + torch.where( need_replace)[0].tolist() + [ cur_input_ids.shape[0] ] cur_input_ids_noim = [] cur_labels = labels[batch_idx] cur_labels_noim = [] for i in range(len(image_token_indices) - 1): cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]]) cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]]) cur_new_inputs_ids = [] cur_new_labels = [] 
        cur_new_vision_ids = []
        for i in range(num_replace + 1):
            cur_new_inputs_ids.append(cur_input_ids_noim[i])
            cur_new_vision_ids.append(cur_input_ids_noim[i] * 0 +
                                      NON_VISION_TOKEN)
            cur_new_labels.append(cur_labels_noim[i])
            if i < num_replace:
                # image
                cur_vision_tokens = visual_tokens[cur_image_idx].to(
                    ori_input_ids)
                cur_new_inputs_ids.append(cur_vision_tokens)
                cur_new_vision_ids.append(vision_patch_indices[cur_image_idx])
                cur_new_labels.append(
                    torch.full((cur_vision_tokens.shape[0], ),
                               IGNORE_INDEX,
                               device=cur_labels.device,
                               dtype=cur_labels.dtype))
                cur_image_idx += 1
        cur_new_inputs_ids = torch.cat(cur_new_inputs_ids)
        cur_new_vision_ids = torch.cat(cur_new_vision_ids)
        cur_new_labels = torch.cat(cur_new_labels)
        new_inputs_ids.append(cur_new_inputs_ids)
        new_vision_ids.append(cur_new_vision_ids)
        new_labels.append(cur_new_labels)

    # Combine them
    max_len = max(x.shape[0] for x in new_inputs_ids)
    batch_size = len(new_inputs_ids)

    new_inputs_ids_padded = []
    new_vision_ids_padded = []
    new_labels_padded = torch.full((batch_size, max_len),
                                   IGNORE_INDEX,
                                   dtype=new_labels[0].dtype,
                                   device=new_labels[0].device)
    attention_mask = torch.zeros((batch_size, max_len),
                                 dtype=attention_mask.dtype,
                                 device=attention_mask.device)
    position_ids = torch.zeros((batch_size, max_len),
                               dtype=position_ids.dtype,
                               device=position_ids.device)

    for i, (cur_new_id,
            cur_new_labels) in enumerate(zip(new_inputs_ids, new_labels)):
        # print(i, new_vision_ids)
        cur_vision_id = new_vision_ids[i]
        cur_len = cur_new_id.shape[0]
        new_inputs_ids_padded.append(
            torch.cat((cur_new_id,
                       torch.zeros((max_len - cur_len, ),
                                   dtype=cur_new_id.dtype,
                                   device=cur_new_id.device)),
                      dim=0))
        new_vision_ids_padded.append(
            torch.cat((cur_vision_id,
                       torch.zeros((max_len - cur_len, ),
                                   dtype=cur_new_id.dtype,
                                   device=cur_new_id.device) +
                       NON_VISION_TOKEN),
                      dim=0))
        if cur_len > 0:
            new_labels_padded[i, :cur_len] = cur_new_labels
            attention_mask[i, :cur_len] = True
            position_ids[i, :cur_len] = torch.arange(
                0,
                cur_len,
                dtype=position_ids.dtype,
                device=position_ids.device)

    new_inputs_ids = torch.stack(new_inputs_ids_padded, dim=0)
    new_vision_ids = torch.stack(new_vision_ids_padded, dim=0)

    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None

    # map every <vpatch> position in the merged ids to its patch index
    vpatch_id = tokenizer.encode('<vpatch>', add_special_tokens=False)[0]
    vpatch_indices = new_inputs_ids.clone().detach()
    vpatch_indices[vpatch_indices != vpatch_id] = NON_VISION_TOKEN
    if vision_patches is not None:
        assert vision_patches.size(0) == \
            (vpatch_indices == vpatch_id).sum().item(), \
            ('number of vision patches does not match the number indicated '
             f'by the indices: {vision_patches.size(0)} vs '
             f'{(vpatch_indices == vpatch_id).sum().item()}.')
        vpatch_indices[vpatch_indices == vpatch_id] = torch.arange(
            (vpatch_indices == vpatch_id).sum(), device=vpatch_indices.device)

    return {
        'input_ids': new_inputs_ids,
        'position_ids': position_ids,
        'attention_mask': attention_mask,
        'past_key_values': past_key_values,
        # 'inputs_embeds': None,
        'labels': new_labels,
        'vision_patch_indices': vpatch_indices,  # new_vision_ids
        'vision_patches': vision_patches,  # only 1 image
    }


def pad_32(val):
    # round up to the nearest multiple of 32
    if val % 32 == 0:
        return val
    else:
        return (val // 32 + 1) * 32
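

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by the model above): how an image tensor
# is turned into flattened 32x32 patches, mirroring the unfold/permute/view
# logic inside `prepare_inputs_labels_for_multimodal_solo`. The image size
# below is an arbitrary assumption for the example.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    patch_size = 32
    # a fake RGB image whose sides are first rounded up to multiples of 32
    h, w = pad_32(300), pad_32(500)  # -> 320, 512
    image = torch.randn(3, h, w)

    # (C, H, W) -> (C, N_H_PATCHES, N_W_PATCHES, PATCH_H, PATCH_W)
    patches = image.unfold(1, patch_size, patch_size) \
        .unfold(2, patch_size, patch_size)
    # -> (N_H_PATCHES, N_W_PATCHES, C, PATCH_H, PATCH_W)
    patches = patches.permute(1, 2, 0, 3, 4).contiguous()
    n_rows, n_cols = patches.shape[:2]  # 10 x 16 patches
    # one flattened vector of C * PATCH_H * PATCH_W values per patch
    patches = patches.view(n_rows * n_cols, -1)
    print(patches.shape)  # torch.Size([160, 3072])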