from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import RMSNorm

from transformers import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel

from .configuration_qualityv import QualityvConfig, QualityLinearAdapterConfig


class QualityLinearAdapter(nn.Module):
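    """Projects encoder hidden states into the LLM embedding space.

    With ``num_layers == 1`` this is a single linear projection; otherwise it
    is an MLP whose hidden layers use ``intermediate_size`` features with the
    activation named by ``config.act_fn`` in between. The input is
    RMS-normalized before projection.
    """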
    def __init__(self, config: QualityLinearAdapterConfig):
        super().__init__()
        self.config = config
        self.norm = RMSNorm(config.in_hidden_size)
        self.act_fn = ACT2FN[config.act_fn]
        if config.num_layers == 1:
            self.linears = nn.Linear(config.in_hidden_size, config.out_hidden_size)
        else:
            layers = []
            in_features = config.in_hidden_size
            for _ in range(config.num_layers - 1):
                layers.append(nn.Linear(in_features, config.intermediate_size))
                layers.append(self.act_fn)
                # Hidden layers after the first consume intermediate_size features.
                in_features = config.intermediate_size
            layers.append(nn.Linear(config.intermediate_size, config.out_hidden_size))
            self.linears = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linears(self.norm(x))
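
    # Illustrative usage (the dims and act_fn below are assumptions, not values
    # taken from configuration_qualityv): a two-layer adapter mapping 1024-d
    # vision features into a 4096-d LLM embedding space.
    #
    #   cfg = QualityLinearAdapterConfig(in_hidden_size=1024, intermediate_size=2048,
    #                                    out_hidden_size=4096, act_fn="gelu", num_layers=2)
    #   adapter = QualityLinearAdapter(cfg)
    #   adapter(torch.randn(1, 196, 1024)).shape  # -> torch.Size([1, 196, 4096])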


class QualityvForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal language model that augments a pretrained text LLM with optional
    vision and audio encoders, each bridged into the LLM's embedding space by
    a QualityLinearAdapter."""

    config_class = QualityvConfig

    def __init__(self, config: QualityvConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config = config
        self.llm_model = AutoModelForCausalLM.from_pretrained(config.llm_model_name)
        if config.vision_config is not None:
            self.vision_model = AutoModel.from_pretrained(config.vision_model_name)
            self.vision_adapter = QualityLinearAdapter(config.vision_adapter_config)
        if config.audio_config is not None:
            self.audio_model = AutoModel.from_pretrained(config.audio_model_name)
            self.audio_adapter = QualityLinearAdapter(config.audio_adapter_config)
            # Decoder-start prompt ids for the seq2seq audio model; its decoder
            # is not used by get_audio_features, which calls the encoder only.
            self.decoder_input_ids = (
                torch.tensor([[1, 1]]) * self.audio_model.config.decoder_start_token_id
            )
        self.post_init()

    def get_input_embeddings(self):
        return self.llm_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.llm_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.llm_model.get_output_embeddings()

    def set_output_embeddings(self, value):
        self.llm_model.set_output_embeddings(value)

    def set_decoder(self, decoder):
        self.llm_model.set_decoder(decoder)

    def get_decoder(self):
        return self.llm_model.get_decoder()

    def get_vision_model(self):
        return self.vision_model

    def get_audio_model(self):
        return self.audio_model

    def get_video_features(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        """Encode video frames and project them into the LLM embedding space."""
        video_embeds = self.vision_model(pixel_values_videos).last_hidden_state
        return self.vision_adapter(video_embeds)

    def get_audio_features(self, audio_values: torch.Tensor) -> torch.Tensor:
        """Encode audio with the audio model's encoder and project the result."""
        audio_embeds = self.audio_model.encoder(audio_values).last_hidden_state
        return self.audio_adapter(audio_embeds)

    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode images and project them into the LLM embedding space."""
        image_embeds = self.vision_model(pixel_values).last_hidden_state
        return self.vision_adapter(image_embeds)

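    # Shape sketch (assumed; exact sizes depend on the underlying encoders):
    # a ViT-style vision tower maps pixel_values of shape (batch, 3, H, W) to
    # last_hidden_state of shape (batch, num_patches, vision_hidden), and the
    # adapter projects vision_hidden -> llm_hidden so each patch embedding can
    # stand in for one placeholder token in the text sequence.
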
    def replace_multi_modal_embeddings(self, multi_modal_embeds: torch.Tensor,
                                       input_embeds: torch.Tensor,
                                       input_ids: torch.LongTensor,
                                       multi_modal_token_id: int,
                                       note: str = "multi_modal"):
        """Scatter multi-modal embeddings into the placeholder token positions.

        Args:
            multi_modal_embeds: Adapter outputs; flattened to (num_embeds, hidden_size).
            input_embeds: Text embeddings of shape (batch, seq_len, hidden_size).
            input_ids: Token ids used to locate the placeholder positions.
            multi_modal_token_id: Id of the placeholder token to replace.
            note: Modality name used in error messages.

        Returns:
            input_embeds with every placeholder position overwritten by the
            corresponding multi-modal embedding.
        """
        hidden_size = multi_modal_embeds.shape[-1]
        multi_modal_embeds = multi_modal_embeds.reshape(-1, hidden_size)
        n_modal_tokens = (input_ids == multi_modal_token_id).sum().item()
        n_modal_embeds = multi_modal_embeds.shape[0]
        if n_modal_tokens != n_modal_embeds:
            raise ValueError(
                f"The number of {note} tokens ({n_modal_tokens}) does not match "
                f"the number of {note} embeddings ({n_modal_embeds})."
            )
        mask = (input_ids == multi_modal_token_id).unsqueeze(-1)
        multi_modal_mask = mask.expand_as(input_embeds).to(input_embeds.device)
        multi_modal_embeds = multi_modal_embeds.to(input_embeds.device, dtype=input_embeds.dtype)
        return input_embeds.masked_scatter(multi_modal_mask, multi_modal_embeds)

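    # Toy illustration (token ids assumed): with input_ids = [[5, IMG, IMG, 7]]
    # and two patch embeddings, masked_scatter writes the two embeddings into
    # the two IMG positions of input_embeds and leaves the text positions as-is.
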
    def forward(self,
                input_ids: Optional[torch.LongTensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                pixel_values: Optional[torch.Tensor] = None,
                pixel_values_videos: Optional[torch.FloatTensor] = None,
                audio_values: Optional[torch.FloatTensor] = None,
                cache_position: Optional[torch.LongTensor] = None,
                **kwargs
                ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.llm_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.llm_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.llm_config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Swap each modality's placeholder tokens for the corresponding features.
        if pixel_values_videos is not None:
            video_features = self.get_video_features(pixel_values_videos)
            inputs_embeds = self.replace_multi_modal_embeddings(
                video_features, inputs_embeds, input_ids, self.config.video_token_id, note="video"
            )

        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)
            inputs_embeds = self.replace_multi_modal_embeddings(
                image_features, inputs_embeds, input_ids, self.config.image_token_id, note="image"
            )

        if audio_values is not None:
            audio_features = self.get_audio_features(audio_values)
            inputs_embeds = self.replace_multi_modal_embeddings(
                audio_features, inputs_embeds, input_ids, self.config.audio_token_id, note="audio"
            )

        # input_ids are deliberately not forwarded: the multi-modal content has
        # already been merged into inputs_embeds.
        outputs = self.llm_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs
        )

        return outputs

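    # Sketch of a full forward pass (the processor name and call are
    # assumptions; the preprocessing pipeline is not defined in this file):
    #
    #   model = QualityvForCausalLM.from_pretrained("path/to/qualityv")
    #   inputs = processor(text=prompt, videos=clip, return_tensors="pt")
    #   out = model(**inputs, labels=inputs["input_ids"])
    #   out.loss.backward()
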
    def prepare_inputs_for_generation(self,
                                      input_ids,
                                      past_key_values=None,
                                      attention_mask=None,
                                      use_cache=None,
                                      pixel_values=None,
                                      pixel_values_videos=None,
                                      audio_values=None,
                                      cache_position=None,
                                      **kwargs):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            use_cache=use_cache,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            audio_values=audio_values,
            cache_position=cache_position,
            **kwargs
        )
        # Multi-modal features are consumed only during the prefill step; drop
        # them afterwards so decode steps do not re-run the encoders.
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
            model_inputs["audio_values"] = None
        return model_inputs

    def _expand_inputs_for_generation(self,
                                      expand_size: int = 1,
                                      is_encoder_decoder: bool = False,
                                      input_ids: Optional[torch.LongTensor] = None,
                                      **model_kwargs,
                                      ):
        """Expands input tensors for generation when using beam search or sampling.

        Args:
            expand_size (int, optional): Factor by which to expand the batch dimension. Defaults to 1.
            is_encoder_decoder (bool, optional): Whether the model is an encoder-decoder model. Defaults to False.
            input_ids (Optional[torch.LongTensor], optional): The input token IDs. Defaults to None.
            **model_kwargs: Additional model-specific keyword arguments.

        Returns:
            Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: The expanded input_ids and model_kwargs.
        """
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        # Every batched tensor, including the multi-modal inputs, is expanded.
        for key in ("attention_mask", "position_ids", "pixel_values",
                    "pixel_values_videos", "audio_values"):
            if model_kwargs.get(key) is not None:
                model_kwargs[key] = model_kwargs[key].repeat_interleave(expand_size, dim=0)

        # cache_position is intentionally left alone: it is a 1-D tensor of
        # sequence positions shared by every row of the expanded batch, so
        # repeating it along dim 0 would corrupt it.

        return input_ids, model_kwargs
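
# Hypothetical generation call (argument names assumed; the processor is not
# defined in this file). With num_beams=4, _expand_inputs_for_generation
# replicates the pixel/audio tensors 4x to match the expanded text batch:
#
#   inputs = processor(text=prompt, audio=waveform, return_tensors="pt")
#   out_ids = model.generate(**inputs, num_beams=4, max_new_tokens=64)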