# NVILA-Lite-2B-hf-preview / modeling_vila_hf.py
# Author: AndyZijianZhang — commit 10d7cc5 ("feat: new model")
import os
from typing import Optional, Type, Union, cast, override
import transformers.modeling_utils as modeling_utils
from torch import FloatTensor, LongTensor, Tensor
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
from .configuration_vila import VILAConfig
from .modeling_vila import VILAForCausalLM
# Placeholder token id for image positions in `input_ids`; `_embed` overwrites
# the text embeddings at these positions with projected image features.
# Presumably a reserved/special id in the Qwen2 tokenizer vocabulary — confirm
# against the tokenizer config.
IMAGE_TOKEN_ID = 151649
class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """Transformers-native wrapper around a loaded :class:`VILAForCausalLM`.

    Re-exposes the inner model's three sub-modules (vision tower, multimodal
    projector, Qwen2 language model) behind the standard ``PreTrainedModel`` /
    ``GenerationMixin`` interface, so the model can be loaded with
    ``from_pretrained`` and driven with ``generate`` like any other HF
    causal LM.
    """

    config_class: Type[PretrainedConfig] = VILAConfig
    base_model_prefix: str = "vila"
    is_parallelizable: bool = True
    main_input_name: str = "input_ids"

    # Sub-modules shared (not copied) from the wrapped VILAForCausalLM;
    # assigned in __init__.
    config: PretrainedConfig
    mm_projector: PreTrainedModel
    llm: Qwen2ForCausalLM
    vision_tower: PreTrainedModel

    def __init__(
        self,
        config: PretrainedConfig,
        model: VILAForCausalLM,
        *args,
        **kwargs,
    ):
        """Wrap an already-constructed inner VILA model.

        Args:
            config: Model configuration, forwarded to ``PreTrainedModel``.
            model: Source ``VILAForCausalLM`` whose ``mm_projector``, ``llm``
                and ``vision_tower`` sub-modules are reused directly (the
                wrapper holds references, it does not clone weights).
        """
        super().__init__(config, *args, **kwargs)
        self.mm_projector = cast(PreTrainedModel, model.mm_projector)
        self.llm = cast(Qwen2ForCausalLM, model.llm)
        self.vision_tower = cast(PreTrainedModel, model.vision_tower)

    def forward(
        self,
        *,
        attention_mask: Optional[Tensor] = None,
        input_ids: Optional[LongTensor] = None,
        inputs_embeds: Optional[FloatTensor] = None,
        pixel_values: Optional[Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the language model on text (and optionally image) inputs.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given
        (enforced by the asserts below). When ``input_ids`` is given,
        embeddings are built by :meth:`_embed`, splicing projected image
        features into ``IMAGE_TOKEN_ID`` positions.

        Args:
            attention_mask: Forwarded unchanged to the inner LLM.
            input_ids: Token ids; mutually exclusive with ``inputs_embeds``.
            inputs_embeds: Pre-computed embeddings; when supplied,
                ``input_ids`` and ``pixel_values`` must be ``None``.
            pixel_values: Image batch for the vision tower; only consumed on
                the prefill step.
            **kwargs: Extra arguments (e.g. ``past_key_values``) forwarded to
                the inner LLM.

        Returns:
            The inner ``Qwen2ForCausalLM`` output.
        """
        # Vision info is only used for prefilling.
        # NOTE(review): this treats any non-None past_key_values as a decode
        # step. Newer transformers versions may pass an *empty* Cache object
        # even on the prefill step, which would also drop pixel_values here —
        # confirm against the targeted transformers version.
        if kwargs.get("past_key_values", None) is not None:
            pixel_values = None
        if inputs_embeds is None:
            assert input_ids is not None
            inputs_embeds = self._embed(input_ids, pixel_values)
        else:
            # Caller supplied embeddings directly; reject redundant inputs
            # rather than silently ignoring them.
            assert input_ids is None
            assert pixel_values is None
        # NOTE(review): calling .forward() directly bypasses nn.Module hooks;
        # `self.llm(...)` would be the hook-safe spelling — confirm intent.
        outputs = self.llm.forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )
        return outputs

    @override
    @classmethod
    @modeling_utils.restore_default_torch_dtype
    def from_pretrained(
        cls: Type[modeling_utils.SpecificPreTrainedModelType],
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        weights_only: bool = True,
        **kwargs,
    ) -> modeling_utils.SpecificPreTrainedModelType:
        """Load the inner VILA model, then build this wrapper around it.

        First resolves a :class:`VILAConfig` and loads the full
        :class:`VILAForCausalLM`; its state dict is then fed to the base
        ``PreTrainedModel.from_pretrained`` (with ``None`` as the path and the
        inner model as the positional ``model`` arg consumed by ``__init__``)
        so the wrapper's weights are initialized from the inner model.
        """
        state_dict = kwargs.pop("state_dict", None)
        if pretrained_model_name_or_path is not None:
            # A path/repo id wins over any caller-supplied config.
            config = VILAConfig.from_pretrained(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                revision=revision,
                use_safetensors=use_safetensors,
                **kwargs,
            )
        else:
            assert (
                config is not None and state_dict is not None
            ), "Both config and state_dict must be provided if pretrained_model_name_or_path is None."
        # NOTE(review): this load runs unconditionally, and the overwrite
        # below discards any caller-supplied state_dict asserted above —
        # verify the pretrained_model_name_or_path=None path actually works.
        inner_model = VILAForCausalLM.from_pretrained(
            pretrained_model_name_or_path,  # type: ignore
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )
        state_dict = inner_model.state_dict()
        # Prefix keys with "model.".
        # state_dict = {f"model.{k}": v for k, v in state_dict.items()}
        return super().from_pretrained(
            None,
            inner_model,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            state_dict=state_dict,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

    def _embed(
        self,
        input_ids: LongTensor,
        pixel_values: Optional[Tensor],
    ) -> FloatTensor:
        """Gets the embedding of the input ids and pixel values.

        Text tokens are embedded through the LLM's input-embedding table;
        when ``pixel_values`` is given, image features from the vision tower
        are projected into the LLM embedding space and scattered in-place
        into the positions holding ``IMAGE_TOKEN_ID``.

        Args:
            input_ids: The input ids.
            pixel_values: The pixel values.

        Returns:
            The embedding of the input ids and pixel values.
        """
        text_embedding = self.llm.get_input_embeddings().__call__(input_ids)
        text_embedding = cast(FloatTensor, text_embedding)
        if pixel_values is None:
            return text_embedding
        image_features: Tensor = self.vision_tower.__call__(pixel_values)
        image_features: Tensor = self.mm_projector.__call__(image_features)
        # Projector output is (n_images, tokens_per_image, hidden); flatten to
        # one feature row per image token position.
        n_images, n_feature, dim_feature = image_features.shape
        image_features = image_features.view(n_images * n_feature, dim_feature)
        image_token_mask = input_ids == IMAGE_TOKEN_ID
        # In-place masked scatter: silently assumes the number of
        # IMAGE_TOKEN_ID positions equals n_images * n_feature (and matching
        # dtype/device) — not validated here.
        text_embedding[image_token_mask] = image_features
        return text_embedding