"""HuggingFace-compatible conditional-generation wrapper around VILA."""

import os
from typing import Optional, Type, Union, cast, override  # typing.override requires Python >= 3.12

import transformers.modeling_utils as modeling_utils
from torch import FloatTensor, LongTensor, Tensor
from transformers.configuration_utils import PretrainedConfig
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM

from .configuration_vila import VILAConfig
from .modeling_vila import VILAForCausalLM
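
# Id of the image placeholder token: _embed() replaces each occurrence with one
# projected vision feature. The value is assumed to match the placeholder id in
# the tokenizer shipped with the checkpoint.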
IMAGE_TOKEN_ID = 151649


class VILAForConditionalGeneration(PreTrainedModel, GenerationMixin):
    """Exposes a VILAForCausalLM through the standard PreTrainedModel /
    GenerationMixin interface, sharing (not copying) its submodules."""

    config_class: Type[PretrainedConfig] = VILAConfig
    base_model_prefix: str = "vila"
    is_parallelizable: bool = True
    main_input_name: str = "input_ids"

    config: PretrainedConfig

    mm_projector: PreTrainedModel
    llm: Qwen2ForCausalLM
    vision_tower: PreTrainedModel

    def __init__(
        self,
        config: PretrainedConfig,
        model: VILAForCausalLM,
        *args,
        **kwargs,
    ):
        super().__init__(config, *args, **kwargs)

        # Re-expose the inner model's submodules; parameters are shared with
        # `model`, not copied.
        self.mm_projector = cast(PreTrainedModel, model.mm_projector)
        self.llm = cast(Qwen2ForCausalLM, model.llm)
        self.vision_tower = cast(PreTrainedModel, model.vision_tower)

    def forward(
        self,
        *,
        attention_mask: Optional[Tensor] = None,
        input_ids: Optional[LongTensor] = None,
        inputs_embeds: Optional[FloatTensor] = None,
        pixel_values: Optional[Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        # A populated KV cache means we are past prefill: the image features
        # are already baked into the cache, so drop the pixels.
        if kwargs.get("past_key_values") is not None:
            pixel_values = None

        if inputs_embeds is None:
            assert input_ids is not None
            inputs_embeds = self._embed(input_ids, pixel_values)
        else:
            # Embeddings were supplied directly; ids and pixels must not also be set.
            assert input_ids is None
            assert pixel_values is None

        return self.llm.forward(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **kwargs,
        )

    @override
    @classmethod
    @modeling_utils.restore_default_torch_dtype
    def from_pretrained(
        cls: Type[modeling_utils.SpecificPreTrainedModelType],
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        weights_only: bool = True,
        **kwargs,
    ) -> modeling_utils.SpecificPreTrainedModelType:
        state_dict = kwargs.pop("state_dict", None)

        if pretrained_model_name_or_path is not None:
            config = VILAConfig.from_pretrained(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                revision=revision,
                use_safetensors=use_safetensors,
                **kwargs,
            )
        else:
            assert (
                config is not None and state_dict is not None
            ), "Both config and state_dict must be provided if pretrained_model_name_or_path is None."

        # Load the inner model first. Forwarding state_dict lets the
        # pretrained_model_name_or_path=None branch load from memory.
        inner_model = VILAForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            state_dict=state_dict,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

        # Re-export the loaded weights so the super() call below populates the
        # wrapper from memory rather than re-reading the checkpoint.
        state_dict = inner_model.state_dict()

        return super().from_pretrained(
            None,
            inner_model,  # forwarded positionally into __init__(config, model)
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            state_dict=state_dict,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

    def _embed(
        self,
        input_ids: LongTensor,
        pixel_values: Optional[Tensor],
    ) -> FloatTensor:
        """Embeds the input ids, splicing projected image features into the
        positions held by image placeholder tokens.

        Args:
            input_ids: Token ids; image positions carry IMAGE_TOKEN_ID.
            pixel_values: Preprocessed image pixels, or None for text-only input.

        Returns:
            The combined text and image embedding.
        """
        text_embedding = self.llm.get_input_embeddings().__call__(input_ids)
        text_embedding = cast(FloatTensor, text_embedding)

        if pixel_values is None:
            return text_embedding

        # Encode the images, then project them into the LLM embedding space.
        image_features: Tensor = self.vision_tower.__call__(pixel_values)
        image_features = self.mm_projector.__call__(image_features)

        # Flatten (n_images, n_feature, dim) -> (n_images * n_feature, dim) so
        # each feature row lines up with one placeholder token.
        n_images, n_feature, dim_feature = image_features.shape
        image_features = image_features.view(n_images * n_feature, dim_feature)

        # Scatter the features into the placeholder slots. This assumes the
        # prompt contains exactly n_images * n_feature image tokens and that
        # the projector output dtype matches the embedding dtype.
        image_token_mask = input_ids == IMAGE_TOKEN_ID
        text_embedding[image_token_mask] = image_features

        return text_embedding
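

# Minimal usage sketch, not part of the module: it assumes a local VILA
# checkpoint directory whose AutoProcessor inserts IMAGE_TOKEN_ID placeholders
# and returns matching `pixel_values`; the checkpoint path, image file, and
# prompt format below are hypothetical. Because this module uses relative
# imports, run it as `python -m <package>.<module>`.
if __name__ == "__main__":
    import torch
    from PIL import Image
    from transformers import AutoProcessor

    checkpoint = "./vila-checkpoint"  # hypothetical local path
    model = VILAForConditionalGeneration.from_pretrained(checkpoint)
    processor = AutoProcessor.from_pretrained(checkpoint)

    inputs = processor(
        text="<image>\nDescribe this picture.",
        images=Image.open("example.jpg"),  # hypothetical image file
        return_tensors="pt",
    )
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=64)
    print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])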