# NVILA-Lite-2B-hf-0626 / modeling_vila.py
# Uploaded by AndyZijianZhang with `vila-upload` (commit 7d97786, verified).
import math
from typing import List, Optional, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from transformers.models.siglip.modeling_siglip import SiglipVisionModel

from .configuration_vila import VILAConfig
class DownSample3x3BlockFix(nn.Module):
    """Merge every non-overlapping 3x3 neighbourhood of tokens into a single token.

    The token sequence is read as a square ``feat_size x feat_size`` grid in row-major
    order. The grid is zero-padded at the bottom/right up to a multiple of 3, and each
    3x3 window becomes one token whose feature vector is the nine window tokens
    concatenated in row-major window order.
    """

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
                ``sequence_length`` must be a perfect square.
        Returns:
            The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9),
            where ``image_pad_len == ceil(sqrt(sequence_length) / 3) ** 2``.
        Raises:
            ValueError: If ``sequence_length`` is not a perfect square.
        """
        batch_size, sequence_length, hidden_size = x.shape
        # Exact integer square root; float ``** 0.5`` is not guaranteed to be correctly
        # rounded and could truncate a genuine perfect square to k - 1.
        feat_size = math.isqrt(sequence_length)
        if feat_size * feat_size != sequence_length:
            raise ValueError(f"Cannot take square root: sequence_length {sequence_length} is not a perfect square")
        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)

        # Zero-pad rows and columns at the end so both spatial dims divide by 3.
        pad_after = (3 - feat_size % 3) % 3
        if pad_after > 0:
            # F.pad consumes pairs from the last dim backwards: (hidden, cols, rows).
            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
            feat_size += pad_after

        blocks_per_side = feat_size // 3
        # (B, blk_row, 3, blk_col, 3, H) -> (B, blk_row, blk_col, 3, 3, H) so that each
        # 3x3 window is contiguous before flattening it into one token.
        features = features.reshape(batch_size, blocks_per_side, 3, blocks_per_side, 3, hidden_size)
        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
        return features.reshape(batch_size, -1, 9 * hidden_size)
class MultimodalProjector(nn.Module):
    """Project vision-tower features into the language model's embedding space.

    Only the ``"mlp_downsample_3x3_fix"`` projector type is implemented: a 3x3 spatial
    token merge (``DownSample3x3BlockFix``) followed by LayerNorm/Linear/GELU stages
    mapping ``mm_hidden_size * 9`` -> ``mm_hidden_size * 3`` -> ``hidden_size``.
    """

    layers: nn.Sequential

    def __init__(
        self,
        config: VILAConfig,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        if config.mm_projector_type == "mlp_downsample_3x3_fix":
            merged_size = config.mm_hidden_size * 9
            # The LayerNorm/Linear after the first projection must be sized from the
            # same quantity the first Linear outputs (mm_hidden_size * 3); sizing them
            # from vision_config.hidden_size mismatched shapes whenever the two differ.
            intermediate_size = config.mm_hidden_size * 3
            self.layers = nn.Sequential(
                DownSample3x3BlockFix(),
                nn.LayerNorm(merged_size),
                nn.Linear(merged_size, intermediate_size),
                nn.GELU(),
                nn.LayerNorm(intermediate_size),
                nn.Linear(intermediate_size, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        else:
            raise NotImplementedError(f"Unsupported mm_projector_type: {config.mm_projector_type}")
        # ``Tensor.type(None)`` returns a type string instead of casting, so only cast
        # when the config actually specifies a dtype.
        torch_dtype = getattr(config, "torch_dtype", None)
        if torch_dtype is not None:
            self.layers.type(torch_dtype)

    @property
    def device(self) -> torch.device:
        """Device of the projector's parameters."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the projector's parameters."""
        return next(self.parameters()).dtype

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
                It is moved to the projector's device and dtype before projection.
        Returns:
            The output tensor of shape (batch_size, image_pad_len, hidden_size).
        """
        return self.layers(x.to(device=self.device, dtype=self.dtype))
class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """VILA vision-language model: SigLIP vision tower, multimodal projector and Qwen2 LLM.

    On the prefill step, positions in ``input_ids`` holding ``config.image_token_id`` are
    overwritten with projected vision features computed from ``pixel_values``. The
    combined embeddings are then passed to the language model.
    """

    config_class: Type[PretrainedConfig] = VILAConfig
    base_model_prefix: str = "llm"
    _auto_class = "AutoModelForImageTextToText"
    _no_split_modules: List[str] = ["MultimodalProjector"]
    _skip_keys_device_placement: List[str] = ["past_key_values"]
    supports_gradient_checkpointing = True
    _supports_flash_attn_2: bool = True
    _supports_sdpa = True

    config: VILAConfig
    llm: Qwen2ForCausalLM
    mm_projector: MultimodalProjector
    vision_tower: SiglipVisionModel

    def __init__(
        self,
        config: VILAConfig,
        *args,
        **kwargs,
    ):
        super().__init__(config, *args, **kwargs)
        # Each sub-model is constructed from its own sub-config.
        self.llm = Qwen2ForCausalLM._from_config(config.text_config, *args, **kwargs)
        self.mm_projector = MultimodalProjector(config)
        self.vision_tower = SiglipVisionModel._from_config(config.vision_config, *args, **kwargs)
        self.post_init()

    def forward(
        self,
        *,
        attention_mask: Optional[Tensor] = None,
        input_ids: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        past_key_values: Optional[Cache] = None,
        pixel_values: Optional[Tensor] = None,
        position_ids: Optional[LongTensor] = None,
        logits_to_keep: Union[int, Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run one forward step of the language model.

        On the prefill step (no cached tokens yet), ``input_ids`` are converted to
        embeddings by ``_embed`` so image placeholder tokens carry vision features.
        On later decoding steps, ``input_ids`` are passed through unchanged.

        Args:
            attention_mask: Attention mask forwarded to the language model.
            input_ids: Token ids. Mutually exclusive with ``inputs_embeds``.
            inputs_embeds: Precomputed embeddings. Mutually exclusive with ``input_ids``.
            past_key_values: KV cache forwarded to the language model.
            pixel_values: Images for the placeholder tokens; only used on prefill.
            position_ids: Position ids forwarded to the language model.
            logits_to_keep: Forwarded to the language model.
            **kwargs: Extra keyword arguments forwarded to the language model.

        Returns:
            The language model's output.

        Raises:
            ValueError: If not exactly one of ``input_ids`` / ``inputs_embeds`` is given,
                or if the number of image tokens does not match the image features.
        """
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        if input_ids is not None and self._is_prefill(past_key_values):
            inputs_embeds = self._embed(input_ids, pixel_values)
            input_ids = None

        # Move every tensor input onto the LLM's device (embeddings also to its dtype).
        device = self.llm.device
        return self.llm(
            attention_mask=attention_mask.to(device=device) if attention_mask is not None else None,
            input_ids=input_ids.to(device=device) if input_ids is not None else None,
            inputs_embeds=(
                inputs_embeds.to(device=device, dtype=self.llm.dtype) if inputs_embeds is not None else None
            ),
            past_key_values=past_key_values,
            position_ids=position_ids.to(device=device) if position_ids is not None else None,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

    @staticmethod
    def _is_prefill(past_key_values: Optional[Cache]) -> bool:
        """Return True when nothing has been cached yet, i.e. this is the prefill step.

        ``generate()`` may pass an *empty* ``Cache`` object on the first step rather than
        ``None``. Treating only ``None`` as prefill would then skip image embedding
        entirely. Objects without ``get_seq_length`` (e.g. legacy tuples) are treated as
        non-empty, as before.
        """
        if past_key_values is None:
            return True
        get_seq_length = getattr(past_key_values, "get_seq_length", None)
        return callable(get_seq_length) and get_seq_length() == 0

    def get_output_embeddings(self) -> nn.Module:
        """Delegate to the language model's ``get_output_embeddings``."""
        return self.llm.get_output_embeddings()

    def _embed(
        self,
        input_ids: Tensor,
        pixel_values: Optional[Tensor],
    ) -> Tensor:
        """Gets the embedding of the input ids and pixel values.

        Args:
            input_ids: The input ids; image placeholders use ``config.image_token_id``.
            pixel_values: The pixel values, or None for text-only input.

        Returns:
            The token embeddings, with image-token positions replaced by projected image
            features when ``pixel_values`` is given.

        Raises:
            ValueError: If video tokens are present, or if the number of image tokens
                differs from the number of projected image feature rows.
        """
        if torch.any(input_ids == self.config.video_token_id):
            raise ValueError("Video token ids should not be present in the input ids.")
        image_token_mask = input_ids == self.config.image_token_id
        # Image placeholders are looked up as id 0; those rows are overwritten below
        # when pixel_values is provided.
        text_embedding: Tensor = self.llm.get_input_embeddings()(input_ids * ~image_token_mask)
        if pixel_values is None:
            return text_embedding

        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
            pixel_values.to(device=self.vision_tower.device, dtype=self.vision_tower.dtype),
            output_hidden_states=True,
        )
        mm_projector_input = self._vision_tower_output_to_mm_projector_input(vision_tower_output)
        image_embedding: Tensor = self.mm_projector(
            mm_projector_input.to(device=self.mm_projector.device, dtype=self.mm_projector.dtype)
        )
        # One row per image token, in the order the tokens appear in input_ids.
        image_embedding = image_embedding.reshape(-1, image_embedding.shape[-1])

        # masked_scatter_ would silently drop surplus features (or fail obscurely on too
        # few), so require an exact one-to-one match up front.
        num_image_tokens = int(image_token_mask.sum())
        if num_image_tokens != image_embedding.shape[0]:
            raise ValueError(
                f"Number of image tokens in input_ids ({num_image_tokens}) does not match "
                f"the number of image features ({image_embedding.shape[0]})."
            )
        text_embedding.masked_scatter_(
            image_token_mask.to(device=text_embedding.device, dtype=torch.bool).unsqueeze(-1),
            image_embedding.to(device=text_embedding.device, dtype=text_embedding.dtype).flatten(),
        )
        return text_embedding

    def _vision_tower_output_to_mm_projector_input(
        self,
        vision_tower_output: BaseModelOutputWithPooling,
    ) -> Tensor:
        """Select the hidden states of layer ``config.mm_vision_select_layer``.

        Only the ``"cls_patch"`` feature mode is supported; it returns the selected
        layer's hidden states unchanged.

        Raises:
            NotImplementedError: For any other ``mm_vision_select_feature``.
        """
        # _embed requests output_hidden_states=True, so hidden_states must be present.
        assert vision_tower_output.hidden_states is not None
        selected_layer_hidden_states = vision_tower_output.hidden_states[self.config.mm_vision_select_layer]
        if self.config.mm_vision_select_feature == "cls_patch":
            return selected_layer_hidden_states
        else:
            raise NotImplementedError(f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}")