import re
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.mistral.modeling_mistral import MistralForCausalLM, MistralModel

from .configuration_lavy import LlavaMistralConfig
|
|
|
|
# Label value for positions that must not be supervised (image features,
# padding); equals the default ``ignore_index`` of torch's CrossEntropyLoss.
IGNORE_INDEX: int = -100
# Placeholder id in ``input_ids`` marking where image features are spliced in.
IMAGE_TOKEN_INDEX: int = -200
|
|
|
|
class CLIPVisionTower(nn.Module):
    """Frozen CLIP vision encoder returning the features of one hidden layer.

    Args:
        vision_tower: Name or path handed to ``from_pretrained`` for the CLIP
            image processor, vision model and vision config.
        args: Config-like object. ``mm_vision_select_layer`` is required;
            ``mm_vision_select_feature`` is optional (default ``"patch"``).
        delay_load: If true, only the vision config is fetched now and the
            weights are loaded on the first ``load_model()`` / ``forward()``.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()
        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
        if delay_load:
            # Config only, so ``hidden_size`` is available before the weights are.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
        else:
            self.load_model()

    def load_model(self, device_map=None):
        """Load the image processor and frozen vision weights (no-op once loaded)."""
        if not self.is_loaded:
            name = self.vision_tower_name
            self.image_processor = CLIPImageProcessor.from_pretrained(name)
            self.vision_tower = CLIPVisionModel.from_pretrained(name, device_map=device_map)
            self.vision_tower.requires_grad_(False)
            self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the configured hidden layer and token subset.

        ``"patch"`` drops the first token of the sequence; ``"cls_patch"``
        keeps every token. Any other mode raises ``ValueError``.
        """
        selected = image_forward_outs.hidden_states[self.select_layer]
        mode = self.select_feature
        if mode == "cls_patch":
            return selected
        if mode == "patch":
            return selected[:, 1:]
        raise ValueError(f"Unexpected select feature: {self.select_feature}")

    @torch.no_grad()
    def forward(self, images):
        """Encode a batch tensor, or a list of single images, without gradients.

        Outputs are cast back to the dtype of the corresponding input. For a
        list input, each element is encoded separately with a leading batch
        dimension of 1 added.
        """
        if not self.is_loaded:
            self.load_model()

        def _encode(pixels):
            outs = self.vision_tower(
                pixels.to(device=self.device, dtype=self.dtype), output_hidden_states=True
            )
            return self.feature_select(outs)

        if not isinstance(images, list):
            return _encode(images).to(images.dtype)
        return [_encode(image.unsqueeze(0)).to(image.dtype) for image in images]

    @property
    def dtype(self):
        """Dtype of the loaded vision model; ``torch.float16`` before loading."""
        if self.is_loaded:
            return self.vision_tower.dtype
        return torch.float16

    @property
    def device(self):
        """Device of the loaded vision model; CPU before loading."""
        if self.is_loaded:
            return self.vision_tower.device
        return torch.device("cpu")

    @property
    def config(self):
        """Vision config: the live model's once loaded, else the pre-fetched one."""
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """Feature width reported by the vision config."""
        return self.config.hidden_size
|
|
|
|
def build_vision_projector(config):
    """Build the module that maps vision features into the LM hidden space.

    ``config.mm_projector_type`` (default ``"linear"``) selects the layout:

    * ``"linear"``: ``nn.Linear(mm_hidden_size, hidden_size)``.
    * ``"mlp{N}x_gelu"`` with N >= 1: N linear layers separated by GELU. The
      first maps ``mm_hidden_size -> hidden_size`` and the rest keep
      ``hidden_size``. ``"mlp2x_gelu"`` gives exactly Linear/GELU/Linear, so
      existing checkpoints (state-dict keys ``0.*`` and ``2.*``) still load.
    * ``"identity"``: ``nn.Identity()``. Only meaningful when
      ``mm_hidden_size == hidden_size``.

    Raises:
        ValueError: for any other projector type.
    """
    projector_type = getattr(config, "mm_projector_type", "linear")
    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    # Guard the regex so a non-string type ends in ValueError, not TypeError.
    mlp_match = re.fullmatch(r"mlp(\d+)x_gelu", projector_type) if isinstance(projector_type, str) else None
    if mlp_match is not None:
        depth = int(mlp_match.group(1))
        if depth >= 1:
            layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
            for _ in range(depth - 1):
                layers.append(nn.GELU())
                layers.append(nn.Linear(config.hidden_size, config.hidden_size))
            return nn.Sequential(*layers)

    if projector_type == "identity":
        return nn.Identity()

    raise ValueError(f"Unknown projector type: {projector_type}")
|
|
|
|
class LlavaMetaModel:
    """Mixin that attaches a CLIP vision tower and a projector to a backbone.

    Meant to be listed before the backbone class in the bases, so that
    ``super().__init__(config)`` reaches the backbone's constructor.
    """

    def __init__(self, config):
        super().__init__(config)
        if not hasattr(config, "mm_vision_tower"):
            return
        # delay_load=True: only the vision config is fetched here; the weights
        # are loaded lazily by the tower itself.
        self.vision_tower = CLIPVisionTower(config.mm_vision_tower, args=config, delay_load=True)
        self.mm_projector = build_vision_projector(config)

    def get_vision_tower(self):
        """Return the vision tower, or None if absent; unwrap it if stored in a list."""
        tower = getattr(self, "vision_tower", None)
        return tower[0] if isinstance(tower, list) else tower
|
|
|
|
class LlavaMetaForCausalLM(ABC):
    """Mixin that splices projected image features into a causal LM's inputs.

    Subclasses implement :meth:`get_model`, returning the backbone that exposes
    ``get_vision_tower()``, ``mm_projector`` and ``embed_tokens``. The subclass
    itself must provide ``config`` and ``device``.
    """

    @abstractmethod
    def get_model(self):
        """Return the backbone holding the vision tower, projector and embeddings."""
        raise NotImplementedError

    def get_vision_tower(self):
        """Return the backbone's vision tower (None when it has none)."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run ``images`` through the vision tower and then the projector.

        The tower's weights are loaded first if they were deferred.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is not None and not vision_tower.is_loaded:
            vision_tower.load_model()
        image_features = vision_tower(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes=None
    ):
        """Replace image placeholder tokens with projected image features.

        Each ``IMAGE_TOKEN_INDEX`` in a sample consumes the next entry of the
        encoded images. That entry's feature rows are inserted at the token's
        position, and the matching label positions are set to ``IGNORE_INDEX``.
        A sample with no placeholder still consumes one image entry. Padding
        (per ``attention_mask``) is stripped before splicing. Sequences are then
        truncated to ``config.tokenizer_model_max_length`` if set, and
        right-padded again.

        Args:
            input_ids: Token ids containing ``IMAGE_TOKEN_INDEX`` placeholders.
            position_ids: Optional; when given, only its dtype/device are reused.
            attention_mask: Optional padding mask (any dtype; nonzero = keep).
            past_key_values: Passed through unchanged.
            labels: Optional labels aligned with ``input_ids``.
            images: A tensor, a 5-D tensor, or a list of tensors. Lists and 5-D
                tensors are concatenated, encoded, and each sample's images are
                flattened into one feature sequence.
            image_sizes: Accepted for interface compatibility; unused here.

        Returns:
            ``(None, position_ids, attention_mask, past_key_values, inputs_embeds, labels)``.
            ``position_ids``, ``attention_mask`` and ``labels`` are None when the
            caller passed None. Otherwise they are right-padded to the new
            length, and the mask is cast back to the caller's dtype. If there is
            no vision tower, no images, or ``input_ids.shape[1] == 1``, the
            inputs are returned unchanged with ``inputs_embeds=None``.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        # --- Encode images ------------------------------------------------
        if isinstance(images, list) or images.ndim == 5:
            if isinstance(images, list):
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat(list(images), dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        # Remember what the caller passed BEFORE defaults are filled in, so that
        # None inputs come back as None (generate() relies on position_ids
        # staying None) and the mask keeps the caller's dtype.
        original_labels = labels
        original_attention_mask = attention_mask
        original_position_ids = position_ids

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # --- Strip padding (sequences become variable length) -------------
        input_ids = [ids[mask] for ids, mask in zip(input_ids, attention_mask)]
        labels = [lab[mask] for lab, mask in zip(labels, attention_mask)]

        # --- Splice image features in place of placeholder tokens ---------
        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = int((cur_input_ids == IMAGE_TOKEN_INDEX).sum())
            if num_images == 0:
                # Text-only sample: still consume one image entry, and append an
                # empty slice of its features. Presumably this keeps every image
                # in the autograd graph (e.g. for distributed training); confirm.
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Segment boundaries: sentinel -1, each placeholder, sentinel len.
            image_token_indices = (
                [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            )
            cur_labels = labels[batch_idx]
            cur_input_ids_noim = []
            cur_labels_noim = []
            for start, end in zip(image_token_indices[:-1], image_token_indices[1:]):
                cur_input_ids_noim.append(cur_input_ids[start + 1 : end])
                cur_labels_noim.append(cur_labels[start + 1 : end])

            # Embed all text segments in one call, then split back into segments.
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)

            cur_new_input_embeds = []
            cur_new_labels = []
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    # Image positions are never supervised.
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],),
                            IGNORE_INDEX,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )

            new_input_embeds.append(torch.cat([x.to(self.device) for x in cur_new_input_embeds]))
            new_labels.append(torch.cat(cur_new_labels))

        # --- Optional truncation -------------------------------------------
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # --- Right-pad back into a batch -----------------------------------
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)
        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device
        )
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            padding = torch.zeros(
                (max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device
            )
            new_input_embeds_padded.append(torch.cat([cur_new_embed, padding], dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # --- Restore the caller's "not provided" / dtype conventions -------
        if original_labels is None:
            new_labels_padded = None
        if original_attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=original_attention_mask.dtype)
        if original_position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels_padded
|
|
|
|
class LlavaMistralModel(LlavaMetaModel, MistralModel):
    """Mistral backbone extended with the LLaVA vision tower and projector."""

    config_class = LlavaMistralConfig

    def __init__(self, config):
        # Cooperative init: LlavaMetaModel.__init__ forwards to
        # MistralModel.__init__ first, then attaches the multimodal modules.
        super(LlavaMistralModel, self).__init__(config)
|
|
|
|
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
    """Mistral causal LM that accepts images alongside token ids (LLaVA style)."""

    config_class = LlavaMistralConfig

    def __init__(self, config):
        # Skip MistralForCausalLM.__init__ (presumably to avoid building a plain
        # MistralModel) and run the next initializer in the MRO; the multimodal
        # backbone and LM head are built here instead.
        super(MistralForCausalLM, self).__init__(config)
        self.model = LlavaMistralModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the LM, first splicing ``images`` into the inputs when no embeddings are given.

        ``pixel_values`` is an alias used when ``images`` is None. Extra
        ``**kwargs`` are accepted but not forwarded to the parent ``forward``.
        """
        if images is None:
            images = pixel_values
        if inputs_embeds is None:
            input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = (
                self.prepare_inputs_labels_for_multimodal(
                    input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes
                )
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate from a prompt, embedding it (with any images) up front.

        The prompt may be given as ``inputs`` or as the keyword ``input_ids``.
        It is always converted to ``inputs_embeds`` before the parent
        ``generate`` is called. Passing ``inputs_embeds`` directly raises
        ``NotImplementedError``.
        """
        if images is None:
            images = pixel_values
        if inputs is None:
            # Accept the usual HF keyword spelling; otherwise the ids would never be embedded.
            inputs = kwargs.pop("input_ids", None)
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("inputs_embeds is not supported")

        inputs_embeds = None
        if images is not None:
            inputs, position_ids, attention_mask, _, inputs_embeds, _ = self.prepare_inputs_labels_for_multimodal(
                inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
            )
        if inputs_embeds is None:
            # No images, or the multimodal path returned the ids unchanged (no
            # vision tower / single-token prompt): embed the text tokens so the
            # prompt is not silently dropped.
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Delegate to the parent, then carry ``images`` / ``image_sizes`` through to ``forward``."""
        # Both keys are popped: the default argument is evaluated eagerly.
        images = kwargs.pop("images", kwargs.pop("pixel_values", None))
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs
|
|
|
|
# Register with the Auto API so AutoModelForCausalLM maps LlavaMistralConfig
# to LlavaMistralForCausalLM.
AutoModelForCausalLM.register(config_class=LlavaMistralConfig, model_class=LlavaMistralForCausalLM)
|
|