import math
from typing import List, Optional, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.siglip.modeling_siglip import SiglipVisionModel

from .configuration_vila import VILAConfig
class DownSample3x3BlockFix(nn.Module):
    """Spatially downsample a square grid of patch features by 3x3.

    Each non-overlapping 3x3 neighborhood of patch tokens is merged into a
    single token by concatenating the 9 feature vectors along the channel
    dimension. Grids whose side is not a multiple of 3 are zero-padded on
    the bottom/right edges first.
    """

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
                sequence_length must be a perfect square (a square patch grid).

        Returns:
            The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).

        Raises:
            ValueError: If sequence_length is not a perfect square.
        """
        batch_size, sequence_length, hidden_size = x.shape
        # math.isqrt is exact for integers, unlike int(n ** 0.5), which can be
        # off by one for large n due to float rounding.
        feat_size = math.isqrt(sequence_length)
        if feat_size * feat_size != sequence_length:
            raise ValueError(f"Cannot take square root: sequence_length {sequence_length} is not a perfect square")
        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)

        # Zero-pad the bottom/right of the (H, W) grid so both spatial dims
        # are divisible by 3. F.pad pads trailing dims first, so the tuple is
        # (channel_lo, channel_hi, W_lo, W_hi, H_lo, H_hi).
        pad_after = (3 - feat_size % 3) % 3
        if pad_after > 0:
            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
            feat_size = feat_size + pad_after

        # Split the grid into non-overlapping 3x3 windows, then flatten each
        # window's 9 feature vectors (raster order) into one channel vector.
        features = features.reshape(batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size)
        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
        features = features.reshape(batch_size, -1, 9 * hidden_size)
        return features
class MultimodalProjector(nn.Module):
    """Projects vision-tower patch features into the LLM embedding space.

    The only supported projector type is "mlp_downsample_3x3_fix": a 3x3
    spatial downsample followed by a LayerNorm/Linear/GELU MLP stack that
    maps mm_hidden_size features to the LLM's hidden_size.
    """

    layers: nn.Sequential

    def __init__(
        self,
        config: "VILAConfig",
        *args,
        **kwargs,
    ):
        """
        Args:
            config: Model config providing mm_projector_type, mm_hidden_size,
                hidden_size and torch_dtype.

        Raises:
            NotImplementedError: If config.mm_projector_type is unsupported.
        """
        super().__init__(*args, **kwargs)
        if config.mm_projector_type == "mlp_downsample_3x3_fix":
            self.layers = nn.Sequential(
                # (B, S, mm_hidden) -> (B, ceil-grid/9, mm_hidden * 9)
                DownSample3x3BlockFix(),
                nn.LayerNorm(config.mm_hidden_size * 9),
                nn.Linear(
                    config.mm_hidden_size * 9,
                    config.mm_hidden_size * 3,
                ),
                nn.GELU(),
                # Use mm_hidden_size consistently here. The original mixed in
                # config.vision_config.hidden_size, which only shape-checks
                # when it equals mm_hidden_size anyway.
                nn.LayerNorm(config.mm_hidden_size * 3),
                nn.Linear(config.mm_hidden_size * 3, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        else:
            raise NotImplementedError(f"Unsupported mm_projector_type: {config.mm_projector_type}")
        # Cast all projector parameters to the configured dtype.
        self.layers.type(config.torch_dtype)

    @property
    def device(self) -> torch.device:
        """Device of the projector parameters.

        Must be a property: forward() and callers access it as an attribute
        (``self.device`` / ``self.mm_projector.device``); a bare method here
        would pass a bound method to ``Tensor.to`` and raise a TypeError.
        """
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the projector parameters (accessed as an attribute)."""
        return next(self.parameters()).dtype

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).

        Returns:
            The output tensor of shape (batch_size, image_pad_len, hidden_size).
        """
        # Move/cast the input to wherever the projector weights live.
        return self.layers(x.to(device=self.device, dtype=self.dtype))
class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """VILA image-text-to-text model: a SigLIP vision tower whose features are
    projected by a MultimodalProjector into the embedding space of a Qwen2 LLM.
    """

    config_class: Type[PretrainedConfig] = VILAConfig
    base_model_prefix: str = "llm"
    _auto_class = "AutoModelForImageTextToText"
    # Keep the projector on a single device when sharding the model.
    _no_split_modules: List[str] = ["MultimodalProjector"]
    _skip_keys_device_placement: List[str] = ["past_key_values"]
    supports_gradient_checkpointing = True
    _supports_flash_attn_2: bool = True
    _supports_sdpa = True

    # Submodules (populated in __init__).
    config: VILAConfig
    llm: Qwen2ForCausalLM
    mm_projector: MultimodalProjector
    vision_tower: SiglipVisionModel

    def __init__(
        self,
        config: VILAConfig,
        *args,
        **kwargs,
    ):
        """Build the LLM, projector, and vision tower from the VILA config.

        Args:
            config: The VILAConfig holding text_config, vision_config, and
                projector hyperparameters.
        """
        super().__init__(config, *args, **kwargs)
        self.llm = Qwen2ForCausalLM._from_config(config.text_config, *args, **kwargs)
        self.mm_projector = MultimodalProjector(config)
        self.vision_tower = SiglipVisionModel._from_config(config.vision_config, *args, **kwargs)
        self.post_init()

    def forward(
        self,
        *,
        attention_mask: Optional[Tensor] = None,
        input_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        past_key_values: Optional[Cache] = None,
        pixel_values: Optional[Tensor] = None,
        position_ids: Optional[LongTensor] = None,
        logits_to_keep: Union[int, Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the LLM, fusing image features into the embeddings at prefill.

        Args:
            attention_mask: Optional attention mask, forwarded to the LLM.
            input_ids: Token ids; exactly one of input_ids / inputs_embeds
                must be provided.
            inputs_embeds: Precomputed input embeddings (alternative to
                input_ids).
            past_key_values: KV cache; ``None`` marks the prefill step, which
                is the only point where pixel_values are consumed.
            pixel_values: Optional image batch for the vision tower.
            position_ids: Optional position ids, forwarded to the LLM.
            logits_to_keep: Forwarded to the LLM to limit which logits are
                computed.
            **kwargs: Extra arguments forwarded to the LLM.

        Returns:
            The CausalLMOutputWithPast produced by the wrapped Qwen2 LLM.

        Raises:
            ValueError: If both or neither of input_ids / inputs_embeds are
                given.
        """
        # XOR of the two "wrong" states: true when both args are present or
        # both are absent; false exactly when one of them is supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")
        if past_key_values is None:  # Prefill
            if input_ids is not None:
                # Replace token ids with embeddings that already contain the
                # projected image features at the image-token positions.
                inputs_embeds = self._embed(input_ids, pixel_values)
                input_ids = None
        # Every tensor is moved to the LLM's device; decode steps (cache
        # present) pass input_ids straight through without re-embedding images.
        outputs = self.llm.__call__(
            attention_mask=(attention_mask.to(device=self.llm.device) if attention_mask is not None else None),
            input_ids=(input_ids.to(device=self.llm.device) if input_ids is not None else None),
            inputs_embeds=(
                inputs_embeds.to(device=self.llm.device, dtype=self.llm.dtype) if inputs_embeds is not None else None
            ),
            past_key_values=past_key_values,
            position_ids=(position_ids.to(device=self.llm.device) if position_ids is not None else None),
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return outputs

    def get_output_embeddings(self) -> nn.Module:
        """Return the LLM's output embedding (LM head) module."""
        return self.llm.get_output_embeddings()

    def _embed(
        self,
        input_ids: Tensor,
        pixel_values: Optional[Tensor],
    ) -> Tensor:
        """Gets the embedding of the input ids and pixel values.

        Image-token positions in input_ids are filled with projected vision
        features; all other positions get ordinary text embeddings.

        Args:
            input_ids: The input ids.
            pixel_values: The pixel values.

        Returns:
            The embedding of the input ids and pixel values.

        Raises:
            ValueError: If any video token ids appear in input_ids (videos are
                not supported by this path).
        """
        if torch.any(input_ids == self.config.video_token_id):
            raise ValueError("Video token ids should not be present in the input ids.")
        image_token_mask = input_ids == self.config.image_token_id
        # Zero out image-token ids before the lookup so they stay a valid
        # vocabulary index; those rows are overwritten by masked_scatter_
        # below, so the placeholder embedding value is irrelevant.
        text_embedding: Tensor = self.llm.get_input_embeddings().__call__(input_ids * ~image_token_mask)
        if pixel_values is None:
            return text_embedding
        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__(
            pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype),
            output_hidden_states=True,
        )
        mm_projector_input = self._vision_tower_output_to_mm_projector_input(vision_tower_output)
        # NOTE(review): relies on MultimodalProjector.device / .dtype being
        # attribute-style (properties), not plain methods — confirm.
        image_embedding: Tensor = self.mm_projector.__call__(
            mm_projector_input.to(device=self.mm_projector.device, dtype=self.mm_projector.dtype)
        )
        # Flatten (num_images, tokens_per_image, dim) -> (total_tokens, dim).
        image_embedding = image_embedding.reshape(-1, image_embedding.shape[-1])
        # Write image embeddings into the masked slots in order; assumes the
        # number of image-token positions equals the total number of image
        # embedding vectors — TODO confirm the processor guarantees this.
        text_embedding.masked_scatter_(
            image_token_mask.to(device=text_embedding.device, dtype=torch.bool).unsqueeze(-1),
            image_embedding.to(device=text_embedding.device, dtype=text_embedding.dtype).flatten(),
        )
        return text_embedding

    def _vision_tower_output_to_mm_projector_input(
        self,
        vision_tower_output: BaseModelOutputWithPooling,
    ) -> Tensor:
        """Select the hidden-state layer/features the projector consumes.

        Args:
            vision_tower_output: Output of the SigLIP vision tower with
                output_hidden_states=True.

        Returns:
            The hidden states at config.mm_vision_select_layer.

        Raises:
            NotImplementedError: If mm_vision_select_feature is anything other
                than "cls_patch".
        """
        assert vision_tower_output.hidden_states is not None
        # mm_vision_select_layer indexes into the tuple of per-layer hidden
        # states (negative values count from the end).
        selected_layer_hidden_states = vision_tower_output.hidden_states[self.config.mm_vision_select_layer]
        if self.config.mm_vision_select_feature == "cls_patch":
            # "cls_patch" keeps the full token sequence unchanged.
            return selected_layer_hidden_states
        else:
            raise NotImplementedError(f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}")