from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mistral.modeling_mistral import MistralForCausalLM, MistralModel

from .configuration_lavy import LlavaMistralConfig

# Label value written at positions that carry no target (image slots, padding).
IGNORE_INDEX = -100
# Placeholder id inside ``input_ids`` marking where image features are spliced in.
IMAGE_TOKEN_INDEX = -200


class CLIPVisionTower(nn.Module):
    """Frozen CLIP vision encoder exposing one hidden layer as image features.

    Args:
        vision_tower: Name or path handed to the CLIP ``from_pretrained`` loaders.
        args: Object providing ``mm_vision_select_layer`` and, optionally,
            ``mm_vision_select_feature`` (defaults to ``"patch"``).
        delay_load: When true only the CLIP config is fetched up front; the
            weights are loaded by ``load_model()`` (called lazily in ``forward``).
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()
        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
        if not delay_load:
            self.load_model()
        else:
            # Keep a config around so ``config`` / ``hidden_size`` work before loading.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self, device_map=None):
        """Load the image processor and vision weights once; later calls are no-ops."""
        if self.is_loaded:
            return
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the configured hidden layer and apply the feature-selection mode.

        ``"patch"`` drops index 0 along dim 1 (presumably the CLS token -- confirm
        against the CLIP model in use); ``"cls_patch"`` keeps every position.

        Raises:
            ValueError: If ``select_feature`` is neither ``"patch"`` nor ``"cls_patch"``.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            image_features = image_features[:, 1:]
        elif self.select_feature != "cls_patch":
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Encode images without tracking gradients.

        Args:
            images: Either a tensor passed to the vision model as one batch, or a
                list of tensors that are each given a leading batch dimension and
                encoded separately.

        Returns:
            Selected features cast back to the input dtype: a tensor for tensor
            input, a list of tensors for list input.
        """
        if not self.is_loaded:
            self.load_model()
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True
                )
                image_features.append(self.feature_select(image_forward_out).to(image.dtype))
            return image_features
        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
        )
        return self.feature_select(image_forward_outs).to(images.dtype)

    @property
    def dtype(self):
        """Dtype of the loaded vision model, or ``torch.float16`` before loading."""
        return self.vision_tower.dtype if self.is_loaded else torch.float16

    @property
    def device(self):
        """Device of the loaded vision model, or CPU before loading."""
        return self.vision_tower.device if self.is_loaded else torch.device("cpu")

    @property
    def config(self):
        """Config of the loaded vision model, or the config fetched for delayed loading."""
        return self.vision_tower.config if self.is_loaded else self.cfg_only

    @property
    def hidden_size(self):
        """Hidden size reported by ``config``."""
        return self.config.hidden_size


def build_vision_projector(config):
    """Build the module that maps vision features to the language-model width.

    Args:
        config: Object providing ``mm_hidden_size``, ``hidden_size`` and,
            optionally, ``mm_projector_type`` (defaults to ``"linear"``).

    Returns:
        A single ``nn.Linear`` for ``"linear"``, or a Linear-GELU-Linear stack
        for ``"mlp2x_gelu"``.

    Raises:
        ValueError: For any other projector type.
    """
    projector_type = getattr(config, "mm_projector_type", "linear")
    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    if projector_type == "mlp2x_gelu":
        return nn.Sequential(
            nn.Linear(config.mm_hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )
    raise ValueError(f"Unknown projector type: {projector_type}")


class LlavaMetaModel:
    """Mixin adding a vision tower and projector to a language-model backbone.

    ``super().__init__(config)`` relies on this class being combined with a base
    whose constructor accepts ``config`` (see ``LlavaMistralModel``).
    """

    def __init__(self, config):
        super().__init__(config)
        if hasattr(config, "mm_vision_tower"):
            # Weights are loaded lazily; only the CLIP config is fetched here.
            self.vision_tower = CLIPVisionTower(config.mm_vision_tower, args=config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

    def get_vision_tower(self):
        """Return the vision tower, unwrapping a list wrapper, or ``None`` if absent."""
        vision_tower = getattr(self, "vision_tower", None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower


class LlavaMetaForCausalLM(ABC):
    """Mixin that splices projected image features into the token embedding sequence."""

    @abstractmethod
    def get_model(self):
        """Return the backbone model that owns ``embed_tokens`` and ``mm_projector``."""
        raise NotImplementedError

    def get_vision_tower(self):
        """Return the backbone's vision tower (may be ``None``)."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run images through the vision tower and the projector."""
        vision_tower = self.get_vision_tower()
        if vision_tower is not None and not vision_tower.is_loaded:
            vision_tower.load_model()
        image_features = vision_tower(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
    ):
        """Replace image placeholder tokens with image features and re-pad the batch.

        Args:
            input_ids: Token ids in which ``IMAGE_TOKEN_INDEX`` marks image slots.
            position_ids: Optional position ids.
            attention_mask: Optional mask; masked-out positions are dropped before
                splicing.
            past_key_values: Passed through unchanged.
            labels: Optional labels aligned with ``input_ids``.
            images: Tensor or list of tensors to encode, or ``None``.
            image_sizes: Accepted for interface compatibility; not used here.

        Returns:
            ``(input_ids, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels)``. When there is no vision tower, no images, or
            only a single token per row, the inputs are returned untouched with
            ``inputs_embeds=None``. Otherwise ``input_ids`` is ``None``,
            ``inputs_embeds`` holds the right-padded spliced embeddings, and each
            of ``position_ids`` / ``attention_mask`` / ``labels`` is rebuilt for the
            new sequence length -- or returned as ``None`` if the caller passed
            ``None`` for it.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if isinstance(images, list) or images.ndim == 5:
            if isinstance(images, list):
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        # Snapshot what the caller actually supplied BEFORE defaults are filled in;
        # the end of this method uses these to hand ``None`` back for arguments
        # that were ``None`` and to restore the caller's attention-mask dtype.
        original_labels = labels
        original_attention_mask = attention_mask
        original_position_ids = position_ids

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Strip masked-out positions so every row holds only real tokens.
        input_ids = [
            cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # Text-only row: still consumes one image-feature entry; the empty
                # slice is concatenated so the embeddings are unchanged.
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Sentinels at both ends turn the placeholder positions into text spans.
            image_token_indices = (
                [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            # Embed all text spans in one call, then split back per span.
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)

            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    # Image positions never contribute to the loss.
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )
            new_input_embeds.append(torch.cat([x.to(self.device) for x in cur_new_input_embeds]))
            new_labels.append(torch.cat(cur_new_labels))

        # Image features can push a row past the configured maximum length.
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Right-pad every row to the longest one in the batch.
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)
        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device
        )
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            new_input_embeds_padded.append(
                torch.cat(
                    [
                        cur_new_embed,
                        torch.zeros(
                            (max_len - cur_len, cur_new_embed.shape[1]),
                            dtype=cur_new_embed.dtype,
                            device=cur_new_embed.device,
                        ),
                    ],
                    dim=0,
                )
            )
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0, cur_len, dtype=position_ids.dtype, device=position_ids.device
                )
        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Mirror the caller's inputs: whatever was None stays None.
        if original_labels is None:
            new_labels_padded = None
        if original_attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=original_attention_mask.dtype)
        if original_position_ids is None:
            position_ids = None
        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels_padded


class LlavaMistralModel(LlavaMetaModel, MistralModel):
    """Mistral backbone extended with the vision tower and projector."""

    config_class = LlavaMistralConfig

    def __init__(self, config):
        super().__init__(config)


class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
    """Mistral causal LM that accepts images alongside token ids."""

    config_class = LlavaMistralConfig

    def __init__(self, config):
        # Skip MistralForCausalLM.__init__ so the backbone built below is the
        # multimodal LlavaMistralModel.
        super(MistralForCausalLM, self).__init__(config)
        self.model = LlavaMistralModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model, splicing in image features when needed.

        ``pixel_values`` is used as a fallback when ``images`` is ``None``. The
        multimodal preparation is skipped when ``inputs_embeds`` is already
        given. Extra ``**kwargs`` are accepted but not forwarded.
        """
        if images is None:
            images = pixel_values
        if inputs_embeds is None:
            input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = (
                self.prepare_inputs_labels_for_multimodal(
                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
                )
            )
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from token ids plus optional images.

        The prompt is always converted to ``inputs_embeds`` before delegating to
        the base ``generate``; ``pixel_values`` is a fallback for ``images``.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is passed in ``kwargs``.
        """
        if images is None:
            images = pixel_values
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("inputs_embeds is not supported")
        if images is not None:
            inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
                inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)
        return super().generate(
            position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Build per-step model inputs, re-attaching image arguments when present.

        ``images`` / ``pixel_values`` / ``image_sizes`` are removed from ``kwargs``
        before the base implementation runs and added back to its result if set.
        """
        images = kwargs.pop("images", kwargs.pop("pixel_values", None))
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs


AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)