from transformers import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch
from torch import nn
from torch.nn import RMSNorm
from typing import List, Optional

from .configuration_qualityv import QualityvConfig, QualityLinearAdapterConfig


class QualityLinearAdapter(nn.Module):
    """RMSNorm followed by a small MLP that projects encoder features into the LLM embedding space."""

    def __init__(self, config: QualityLinearAdapterConfig):
        super().__init__()
        self.config = config
        self.norm = RMSNorm(config.in_hidden_size)
        self.act_fn = ACT2FN[config.act_fn]
        if config.num_layers == 1:
            self.linears = nn.Linear(config.in_hidden_size, config.out_hidden_size)
        else:
            model_list = []
            for i in range(config.num_layers - 1):
                # Only the first layer maps from the encoder width; subsequent
                # hidden layers stay at intermediate_size.
                in_size = config.in_hidden_size if i == 0 else config.intermediate_size
                model_list.append(nn.Linear(in_size, config.intermediate_size))
                model_list.append(self.act_fn)
            model_list.append(nn.Linear(config.intermediate_size, config.out_hidden_size))
            self.linears = nn.Sequential(*model_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linears(self.norm(x))


class QualityvForCausalLM(PreTrainedModel, GenerationMixin):
    def __init__(self, config: QualityvConfig, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.config = config
        self.llm_model = AutoModelForCausalLM.from_pretrained(config.llm_model_name)
        if config.vision_config is not None:
            self.vision_model = AutoModel.from_pretrained(config.vision_model_name)
            self.vision_adapter = QualityLinearAdapter(config.vision_adapter_config)
        if config.audio_config is not None:
            self.audio_model = AutoModel.from_pretrained(config.audio_model_name)
            self.audio_adapter = QualityLinearAdapter(config.audio_adapter_config)
            self.decoder_input_ids = torch.tensor([[1, 1]]) * self.audio_model.config.decoder_start_token_id
        self.post_init()

    def get_input_embeddings(self):
        return self.llm_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.llm_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.llm_model.get_output_embeddings()

    def set_output_embeddings(self, value):
        self.llm_model.set_output_embeddings(value)

    def set_decoder(self, decoder):
        self.llm_model.set_decoder(decoder)

    def get_decoder(self):
        return self.llm_model.get_decoder()

    def get_vision_model(self):
        return self.vision_model

    def get_audio_model(self):
        return self.audio_model

    def get_video_features(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
        video_embeds = self.vision_model(pixel_values_videos).last_hidden_state
        video_embeds = self.vision_adapter(video_embeds)
        return video_embeds

    def get_audio_features(self, audio_values: torch.Tensor) -> torch.Tensor:
        audio_embeds = self.audio_model.encoder(audio_values).last_hidden_state
        audio_embeds = self.audio_adapter(audio_embeds)
        return audio_embeds

    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        image_embeds = self.vision_model(pixel_values).last_hidden_state
        image_embeds = self.vision_adapter(image_embeds)
        return image_embeds

    def replace_multi_modal_embeddings(
        self,
        multi_modal_embeds: torch.Tensor,
        input_embeds: torch.Tensor,
        input_ids: torch.LongTensor,
        multi_modal_token_id: int,
        note: str = "multi_modal",
    ):
        # multi_modal_embeds: (num_items, tokens_per_item, hidden_size) encoder features
        # input_embeds:       (batch_size, seq_length, hidden_size)
        # input_ids:          (batch_size, seq_length)
        # multi_modal_token_id: id of the placeholder token to overwrite
        # note: modality name used in error messages
        hidden_size = multi_modal_embeds.shape[-1]
        multi_modal_embeds = multi_modal_embeds.view(-1, hidden_size)
        n_modal_tokens = (input_ids == multi_modal_token_id).sum()
        n_modal_embeds = multi_modal_embeds.shape[0]
        if n_modal_tokens != n_modal_embeds:
            raise ValueError(
                f"The number of {note} tokens ({n_modal_tokens}) does not match "
                f"the number of {note} embeddings ({n_modal_embeds})."
            )
        mask = input_ids == multi_modal_token_id
        mask_expanded = mask.unsqueeze(-1).expand_as(input_embeds).to(input_embeds.device)
        multi_modal_embeds = multi_modal_embeds.to(input_embeds.device, dtype=input_embeds.dtype)
        # Scatter the modality embeddings into the positions of the placeholder tokens.
        input_embeds = input_embeds.masked_scatter(mask_expanded, multi_modal_embeds)
        return input_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        audio_values: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        output_attentions = (
            output_attentions if output_attentions is not None else self.config.llm_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.llm_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.llm_config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Replace each modality's placeholder tokens with the corresponding encoder features.
        if pixel_values_videos is not None:
            video_features = self.get_video_features(pixel_values_videos)
            inputs_embeds = self.replace_multi_modal_embeddings(
                video_features, inputs_embeds, input_ids, self.config.video_token_id, note="video"
            )
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)
            inputs_embeds = self.replace_multi_modal_embeddings(
                image_features, inputs_embeds, input_ids, self.config.image_token_id, note="image"
            )
        if audio_values is not None:
            audio_features = self.get_audio_features(audio_values)
            inputs_embeds = self.replace_multi_modal_embeddings(
                audio_features, inputs_embeds, input_ids, self.config.audio_token_id, note="audio"
            )

        outputs = self.llm_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )
        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        pixel_values=None,
        pixel_values_videos=None,
        audio_values=None,
        cache_position=None,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            use_cache=use_cache,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            audio_values=audio_values,
            cache_position=cache_position,
            **kwargs,
        )
        # Multimodal features are only needed for the prefill step; once decoding with a cache,
        # the placeholder tokens are no longer part of input_ids, so drop the raw inputs.
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
            model_inputs["audio_values"] = None
        return model_inputs
    def _expand_inputs_for_generation(
        self,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        input_ids: Optional[torch.LongTensor] = None,
        **model_kwargs,
    ):
        """Expands input tensors for generation when using beam search or sampling.

        Args:
            expand_size (int, optional): The size to expand the inputs by. Defaults to 1.
            is_encoder_decoder (bool, optional): Whether the model is an encoder-decoder model. Defaults to False.
            input_ids (Optional[torch.LongTensor], optional): The input token IDs. Defaults to None.
            **model_kwargs: Additional model-specific keyword arguments.

        Returns:
            Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: The expanded input_ids and model_kwargs.
        """
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)

        # Expand attention mask if present
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat_interleave(expand_size, dim=0)

        # Expand position IDs if present
        if "position_ids" in model_kwargs:
            model_kwargs["position_ids"] = model_kwargs["position_ids"].repeat_interleave(expand_size, dim=0)

        # Expand pixel values for images if present
        if "pixel_values" in model_kwargs and model_kwargs["pixel_values"] is not None:
            model_kwargs["pixel_values"] = model_kwargs["pixel_values"].repeat_interleave(expand_size, dim=0)

        # Expand pixel values for videos if present
        if "pixel_values_videos" in model_kwargs and model_kwargs["pixel_values_videos"] is not None:
            model_kwargs["pixel_values_videos"] = model_kwargs["pixel_values_videos"].repeat_interleave(
                expand_size, dim=0
            )

        # Expand audio values if present
        if "audio_values" in model_kwargs and model_kwargs["audio_values"] is not None:
            model_kwargs["audio_values"] = model_kwargs["audio_values"].repeat_interleave(expand_size, dim=0)

        # Expand cache position if present
        if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
            model_kwargs["cache_position"] = model_kwargs["cache_position"].repeat_interleave(expand_size, dim=0)

        return input_ids, model_kwargs
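
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The checkpoint path, the "<image>" prompt
# handling, and the dummy pixel shape below are assumptions, not part of this
# repository; use the processor and config that ship with the real checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/qualityv-checkpoint"  # hypothetical local checkpoint directory

    config = QualityvConfig.from_pretrained(checkpoint)
    model = QualityvForCausalLM(config).eval()
    tokenizer = AutoTokenizer.from_pretrained(config.llm_model_name)

    # The prompt must already contain exactly as many image placeholder tokens
    # (id == config.image_token_id) as get_image_features() will produce patch
    # embeddings, otherwise replace_multi_modal_embeddings raises a ValueError.
    inputs = tokenizer("Describe the image: <image>", return_tensors="pt")

    # Dummy pixel values; real inputs should come from the image processor that
    # matches config.vision_model_name.
    pixel_values = torch.randn(1, 3, 224, 224)

    with torch.no_grad():
        generated = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            pixel_values=pixel_values,
            max_new_tokens=32,
        )
    print(tokenizer.decode(generated[0], skip_special_tokens=True))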