from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
                          LlamaTokenizer)
from transformers.modeling_outputs import CausalLMOutputWithPast

from xtuner.model import InternVL_V1_5


class InternVL(InternVL_V1_5):

    def forward(self, data, data_samples=None, mode='loss'):
        pixel_values = data['pixel_values']

        if type(pixel_values) is list or pixel_values.ndim == 5:
            if type(pixel_values) is list:
                pixel_values = [
                    x.unsqueeze(0) if x.ndim == 3 else x for x in pixel_values
                ]
            # concatenate per-sample tiles into a single (b * n, c, h, w) batch
            concat_images = torch.cat(
                [image.to(self.model.vision_model.dtype) for image in pixel_values],
                dim=0)
        else:
            raise NotImplementedError()

        input_ids = data['input_ids']
        position_ids = data['position_ids']
        attention_mask = data['attention_mask']
        # tiles whose pixel sum is 0 are padding for text-only samples
        image_flags = torch.sum(concat_images, dim=(1, 2, 3)) != 0
        image_flags = image_flags.long()

        labels = data['labels']
        use_cache = False

        outputs = self._llm_forward(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            image_flags=image_flags,
            pixel_values=concat_images,
            labels=labels,
            use_cache=use_cache,
            output_hidden_states=True)
        return outputs

    def _llm_forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        image_flags: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None \
            else self.model.config.use_return_dict

        image_flags = image_flags.squeeze(-1)
        # Clone so the scatter of visual tokens below does not modify the
        # embedding layer's output in place.
        input_embeds = self.model.language_model.get_input_embeddings()(
            input_ids).clone()

        vit_embeds = self.model.extract_feature(pixel_values)
        vit_embeds = vit_embeds.to(input_embeds.dtype)  # FIXME: why is vit_embeds float16?
        vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        self._count += 1

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.model.img_context_token_id)
        try:
            # scatter the visual tokens into the IMG_CONTEXT placeholder positions
            input_embeds[selected] = vit_embeds.reshape(-1, C)
        except Exception as e:
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape='
                  f'{input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = vit_embeds[:n_token]

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.model.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(
                -1, self.model.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, ) + outputs[1:]
            return (loss, ) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        visual_features: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        device = self.model.device
        assert self.model.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                if type(pixel_values) is list or pixel_values.ndim == 5:
                    if type(pixel_values) is list:
                        pixel_values = [
                            x.unsqueeze(0) if x.ndim == 3 else x
                            for x in pixel_values
                        ]
                    # concatenate per-sample tiles into (b * n, c, h, w)
                    pixel_values = torch.cat(
                        [image.to(self.model.vision_model.dtype) for image in pixel_values],
                        dim=0)
                vit_embeds = self.model.extract_feature(pixel_values.to(device))
            # drop the all-zero padding tiles of text-only samples
            image_flags = torch.sum(pixel_values, dim=(1, 2, 3)) != 0
            image_flags = image_flags.long()
            vit_embeds = vit_embeds[image_flags == 1]

            input_embeds = self.model.language_model.get_input_embeddings()(
                input_ids.to(device))
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)
            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.model.img_context_token_id)
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.model.language_model.get_input_embeddings()(input_ids)

        outputs = self.model.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask.to(device),
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )
        return outputs
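

if __name__ == '__main__':
    # Minimal usage sketch: the batch layout that `forward` above expects.
    # The tile count, the 448x448 tile size and the sequence length are
    # illustrative assumptions, not requirements of this class. Building a
    # real `InternVL` instance (and tokenising prompts with IMG_CONTEXT
    # placeholder tokens so that `img_context_token_id` positions exist)
    # needs an actual checkpoint, so only the input construction is shown.
    B, N = 1, 32      # batch size and sequence length (assumed)
    num_tiles = 7     # image tiles for the single sample (assumed)

    data = {
        # list of per-sample tile tensors; forward() concatenates them into a
        # (b * n, c, h, w) batch and treats all-zero tiles as text-only padding
        'pixel_values': [torch.randn(num_tiles, 3, 448, 448)],
        'input_ids': torch.randint(0, 100, (B, N)),
        'position_ids': torch.arange(N).unsqueeze(0),
        'attention_mask': torch.ones(B, N, dtype=torch.long),
        # -100 is the default CrossEntropyLoss ignore index
        'labels': torch.full((B, N), -100, dtype=torch.long),
    }
    # With a constructed model this would return a CausalLMOutputWithPast:
    #     outputs = model(data, mode='loss')
    #     print(outputs.loss)
    for key, value in data.items():
        shape = value[0].shape if isinstance(value, list) else value.shape
        print(f'{key}: {tuple(shape)}')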