from typing import Any, override

import torch
import torch.nn as nn
from torch import FloatTensor, LongTensor, Tensor
from transformers import AutoModel, LlamaForCausalLM, LlamaModel, PreTrainedModel

from .configuration_vlm import VLMConfig
from .connectors import Connector, connector_map


class VLM(LlamaModel):
    config_class = VLMConfig

    @override
    def __init__(self, config):
        super().__init__(config)
        self.vision_model = self._build_vision_model(config)
        self.connector = self._build_connector(config)

    def _build_vision_model(self, config: Any) -> PreTrainedModel:
        vision_config = config.vision_config
        visual_encoder: PreTrainedModel = AutoModel.from_pretrained(
            vision_config.hf_name,
            trust_remote_code=True,
        )
        # Some checkpoints (e.g. CLIP-style models) wrap the image encoder in a
        # `vision_model` attribute; unwrap it so we always hold the encoder itself.
        if getattr(visual_encoder, "vision_model", None) is not None:
            visual_encoder = visual_encoder.vision_model
        return visual_encoder

    def _build_connector(self, config: Any) -> Connector:
        connector_config = config.connector_config
        connector_class = connector_map.get(connector_config.type)
        if connector_class is None:
            raise ValueError(f"Unsupported connector type: {connector_config.type}")
        # Maps vision hidden states into the language model's embedding space.
        return connector_class(
            connector_config,
            config.vision_config.hidden_size,
            config.hidden_size,
        )


class VLMForCausalLM(LlamaForCausalLM):
    config_class = VLMConfig

    @override
    def __init__(self, config):
        super().__init__(config)
        self.model = VLM(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    @override
    def forward(
        self: Any,
        input_ids: Tensor | None = None,
        inputs_embeds: Tensor | None = None,
        attention_mask: Tensor | None = None,
        position_ids: LongTensor | None = None,
        past_key_values: list[FloatTensor] | None = None,
        labels: LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        images: FloatTensor | None = None,
        image_sizes: list[list[int]] | None = None,
        return_dict: bool | None = None,
    ) -> Any:
        # Splice the projected image features into the text embeddings, then
        # delegate to the language model's standard forward pass.
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
            )
        return super().forward(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @override
    def generate(
        self: Any,
        inputs: Tensor | None = None,
        images: FloatTensor | None = None,
        image_sizes: list[list[int]] | None = None,
        **kwargs: Any,
    ) -> Any:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            # Replace image placeholder tokens with projected vision features so
            # the language model only ever sees a sequence of embeddings.
            (_, position_ids, attention_mask, _, inputs_embeds, _) = (
                self.prepare_inputs_labels_for_multimodal(
                    inputs,
                    position_ids,
                    attention_mask,
                    None,
                    None,
                    images,
                )
            )
        else:
            inputs_embeds = self.get_input_embeddings()(inputs)

        return super().generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )

    @override
    def prepare_inputs_for_generation(
        self: Any,
        input_ids: Tensor,
        past_key_values: list[FloatTensor] | None = None,
        inputs_embeds: Tensor | None = None,
        **kwargs: Any,
    ):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        # `cache_position` is derived from the token ids and does not account for
        # the inserted image embeddings, so drop it and let the model recompute it.
        inputs.pop("cache_position", None)
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs

    def encode_images(self: Any, images: list[Tensor] | Tensor) -> list[Tensor] | Tensor:
        output_layer = self.config.vision_config.output_layer
        if isinstance(images, list):
            image_features: list[Tensor] | Tensor = []
            for image in images:
                outputs = self.model.vision_model(
                    image.unsqueeze(0),
                    output_hidden_states=True,
                )
                hidden_states: Tensor = outputs.hidden_states[output_layer].to(image.dtype)
                if not self.config.vision_config.use_cls_token:
                    # Drop the leading CLS token; keep only the patch features.
                    image_features.append(hidden_states[:, 1:])
                else:
                    image_features.append(hidden_states)
        else:
            outputs = self.model.vision_model(
                images,
                output_hidden_states=True,
            )
            hidden_states = outputs.hidden_states[output_layer].to(images.dtype)
            if not self.config.vision_config.use_cls_token:
                image_features = hidden_states[:, 1:]
            else:
                image_features = hidden_states
        # Project the vision features into the language model's embedding space.
        image_features = self.model.connector(image_features)

        return image_features

    def unpad_image(self: Any, tensor: Tensor, original_size: tuple[int, int]) -> Tensor:
        """
        Unpads a PyTorch tensor of a padded and resized image.

        Args:
            tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
            original_size (tuple): The original size of the PIL image as (width, height).

        Returns:
            torch.Tensor: The unpadded image tensor.
        """
        original_width, original_height = original_size
        current_height, current_width = tensor.shape[1:]

        original_aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height

        if original_aspect_ratio > current_aspect_ratio:
            # Width-limited: padding was added above and below; crop it away.
            scale_factor = current_width / original_width
            new_height = int(original_height * scale_factor)
            padding = (current_height - new_height) // 2
            unpadded_tensor = tensor[:, padding : current_height - padding, :]
        else:
            # Height-limited: padding was added left and right; crop it away.
            scale_factor = current_height / original_height
            new_width = int(original_width * scale_factor)
            padding = (current_width - new_width) // 2
            unpadded_tensor = tensor[:, :, padding : current_width - padding]

        return unpadded_tensor
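    # Worked example (illustrative values, not from this module): for an original
    # image of size (width=200, height=100) and a padded feature map of shape
    # (C, 64, 64), the original aspect ratio (2.0) exceeds the current one (1.0),
    # so scale_factor = 64 / 200 = 0.32, new_height = int(100 * 0.32) = 32,
    # padding = (64 - 32) // 2 = 16, and unpad_image returns tensor[:, 16:48, :]
    # with shape (C, 32, 64).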

    def prepare_inputs_labels_for_multimodal(
        self: Any,
        input_ids: Tensor | None = None,
        position_ids: LongTensor | None = None,
        attention_mask: Tensor | None = None,
        past_key_values: list[FloatTensor] | None = None,
        labels: LongTensor | None = None,
        images: FloatTensor | None = None,
    ) -> tuple[
        Tensor | None,
        LongTensor | None,
        Tensor | None,
        list[FloatTensor] | None,
        Tensor | None,
        LongTensor | None,
    ]:
        vision_model = self.model.vision_model
        if vision_model is None or images is None or input_ids.shape[1] == 1:
            # Nothing multimodal to do (no vision tower, no images, or a
            # single-token decoding step); fall back to the plain text path.
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if isinstance(images, list) or images.ndim == 5:
            # Multiple crops/tiles per sample: encode them as one batch, split the
            # features back per sample, and flatten each into one token sequence.
            if isinstance(images, list):
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        # Keep the caller's originals so `None`s can be restored on the way out.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(
                0,
                input_ids.shape[1],
                dtype=torch.long,
                device=input_ids.device,
            )
        if labels is None:
            labels = torch.full_like(input_ids, self.config.ignore_index)

        # Strip padded positions using the attention mask.
        _input_ids = input_ids
        input_ids = [
            cur_input_ids[cur_attention_mask]
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask, strict=False)
        ]
        labels = [
            cur_labels[cur_attention_mask]
            for cur_labels, cur_attention_mask in zip(labels, attention_mask, strict=False)
        ]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == self.config.image_token_index).sum()
            if num_images == 0:
                # No image tokens: keep the text embeddings, concatenating a
                # zero-length slice of the image features so they stay in the graph.
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_input_embeddings()(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Split the sequence at every image token, embed the text pieces, and
            # interleave the image features in between.
            image_token_indices = (
                [-1]
                + torch.where(cur_input_ids == self.config.image_token_index)[0].tolist()
                + [cur_input_ids.shape[0]]
            )
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(
                    cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
                )
                cur_labels_noim.append(
                    cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
                )
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_input_embeddings()(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    # Image positions never contribute to the loss.
                    cur_new_labels.append(
                        torch.full(
                            (cur_image_features.shape[0],),
                            self.config.ignore_index,
                            device=cur_labels.device,
                            dtype=cur_labels.dtype,
                        )
                    )

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate to the maximum sequence length, since inserting image features
        # can push sequences past it.
        tokenizer_model_max_length = self.config.max_seq_length
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Pad every sequence in the batch to the same length and rebuild the
        # attention mask, position ids, and labels to match.
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full(
            (batch_size, max_len),
            self.config.ignore_index,
            dtype=new_labels[0].dtype,
            device=new_labels[0].device,
        )
        attention_mask = torch.zeros(
            (batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
        )
        position_ids = torch.zeros(
            (batch_size, max_len),
            dtype=position_ids.dtype,
            device=position_ids.device,
        )

        for i, (cur_new_embed, cur_new_labels) in enumerate(
            zip(new_input_embeds, new_labels, strict=False)
        ):
            cur_len = cur_new_embed.shape[0]
            if self.config.padding_side == "left":
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                            cur_new_embed,
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(
                        0,
                        cur_len,
                        dtype=position_ids.dtype,
                        device=position_ids.device,
                    )
            else:
                new_input_embeds_padded.append(
                    torch.cat(
                        (
                            cur_new_embed,
                            torch.zeros(
                                (max_len - cur_len, cur_new_embed.shape[1]),
                                dtype=cur_new_embed.dtype,
                                device=cur_new_embed.device,
                            ),
                        ),
                        dim=0,
                    )
                )
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(
                        0,
                        cur_len,
                        dtype=position_ids.dtype,
                        device=position_ids.device,
                    )

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        # Restore `None` for anything the caller did not pass in.
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels


AutoModel.register(VLMConfig, VLMForCausalLM)
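# Hedged usage sketch. The checkpoint path and tensors below are illustrative
# placeholders, not part of this module; preprocessing is assumed to happen
# elsewhere (e.g. a processor that yields pixel values and an `input_ids`
# sequence containing `image_token_index` placeholders):
#
#     model = VLMForCausalLM.from_pretrained("path/to/vlm-checkpoint")
#     output_ids = model.generate(
#         inputs=input_ids,        # (batch, seq_len) token ids with image tokens
#         images=pixel_values,     # (batch, channels, height, width) float tensor
#         max_new_tokens=64,
#     )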