# qualityv-0606 / modeling_qualityv.py
from transformers import AutoModel, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch
from torch import nn
from torch.nn import RMSNorm
from typing import List, Optional
from .configuration_qualityv import QualityvConfig, QualityLinearAdapterConfig
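# QualityLinearAdapter: RMSNorm followed by a small MLP that projects features
# from a modality encoder (vision or audio) into the LLM embedding space.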
class QualityLinearAdapter(nn.Module):
def __init__(self, config: QualityLinearAdapterConfig):
super().__init__()
self.config = config
self.norm = RMSNorm(config.in_hidden_size)
self.act_fn = ACT2FN[config.act_fn]
if config.num_layers == 1:
self.linears = nn.Linear(config.in_hidden_size, config.out_hidden_size)
        else:
            # Chain Linear -> activation blocks; layers after the first take
            # intermediate_size inputs so deeper adapters keep consistent shapes.
            model_list = []
            in_size = config.in_hidden_size
            for _ in range(config.num_layers - 1):
                model_list.append(nn.Linear(in_size, config.intermediate_size))
                model_list.append(self.act_fn)
                in_size = config.intermediate_size
            model_list.append(nn.Linear(config.intermediate_size, config.out_hidden_size))
            self.linears = nn.Sequential(*model_list)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linears(self.norm(x))
return x
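# QualityvForCausalLM: wraps a pretrained causal LLM and optionally attaches
# vision and audio encoders whose adapted features are scattered into the
# placeholder token positions of the text sequence.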
class QualityvForCausalLM(PreTrainedModel, GenerationMixin):
def __init__(self, config: QualityvConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.config = config
self.llm_model = AutoModelForCausalLM.from_pretrained(config.llm_model_name)
if config.vision_config is not None:
self.vision_model = AutoModel.from_pretrained(config.vision_model_name)
self.vision_adapter = QualityLinearAdapter(config.vision_adapter_config)
if config.audio_config is not None:
self.audio_model = AutoModel.from_pretrained(config.audio_model_name)
self.audio_adapter = QualityLinearAdapter(config.audio_adapter_config)
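            # Decoder start tokens for an encoder-decoder audio backbone
            # (Whisper-style); get_audio_features below only uses the encoder.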
self.decoder_input_ids = torch.tensor([[1, 1,]]) * self.audio_model.config.decoder_start_token_id
self.post_init()
def get_input_embeddings(self):
return self.llm_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.llm_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.llm_model.get_output_embeddings()
def set_output_embeddings(self, value):
self.llm_model.set_output_embeddings(value)
def set_decoder(self, decoder):
self.llm_model.set_decoder(decoder)
def get_decoder(self):
return self.llm_model.get_decoder()
def get_vision_model(self):
return self.vision_model
def get_audio_model(self):
return self.audio_model
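    # Each modality path runs the pretrained encoder and projects its
    # last_hidden_state into the LLM embedding space via the matching adapter.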
def get_video_features(self, pixel_values_videos: torch.Tensor) -> torch.Tensor:
video_embeds = self.vision_model(pixel_values_videos).last_hidden_state
video_embeds = self.vision_adapter(video_embeds)
return video_embeds
def get_audio_features(self, audio_values: torch.Tensor) -> torch.Tensor:
audio_embeds = self.audio_model.encoder(audio_values).last_hidden_state
audio_embeds = self.audio_adapter(audio_embeds)
return audio_embeds
def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_embeds = self.vision_model(pixel_values).last_hidden_state
image_embeds = self.vision_adapter(image_embeds)
return image_embeds
def replace_multi_modal_embeddings(self, multi_modal_embeds: torch.Tensor,
input_embeds: torch.Tensor,
input_ids: torch.LongTensor,
multi_modal_token_id: int,
note: str="multi_modal"):
        """Scatter modality embeddings into the positions of their placeholder tokens.

        Args:
            multi_modal_embeds: (batch_size * num_frames, hidden_steps, hidden_size) modality features.
            input_embeds: (batch_size, seq_length, hidden_size) token embeddings.
            input_ids: (batch_size, seq_length) token ids.
            multi_modal_token_id: id of the placeholder token to replace.
            note: modality name used in error messages.
        """
hidden_size = multi_modal_embeds.shape[-1]
multi_modal_embeds = multi_modal_embeds.view(-1, hidden_size)
n_modal_tokens = (input_ids == multi_modal_token_id).sum()
n_modal_embeds = multi_modal_embeds.shape[0]
if n_modal_tokens != n_modal_embeds:
raise ValueError(f"The number of {note} tokens ({n_modal_tokens}) does not match the number of {note} embeddings ({n_modal_embeds}).")
        multi_modal_mask = (input_ids == multi_modal_token_id).unsqueeze(-1)
        multi_modal_mask = multi_modal_mask.expand_as(input_embeds).to(input_embeds.device)
        multi_modal_embeds = multi_modal_embeds.to(input_embeds.device, dtype=input_embeds.dtype)
        input_embeds = input_embeds.masked_scatter(multi_modal_mask, multi_modal_embeds)
return input_embeds
def forward(self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_values: Optional[torch.Tensor] = None,
pixel_values_videos: Optional[torch.FloatTensor] = None,
audio_values: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs
):
output_attentions = output_attentions if output_attentions is not None else self.config.llm_config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.llm_config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.llm_config.use_return_dict
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
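        # Merge video / image / audio features into the positions of their
        # placeholder tokens before running the language model.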
if pixel_values_videos is not None:
video_features = self.get_video_features(pixel_values_videos)
inputs_embeds = self.replace_multi_modal_embeddings(video_features, inputs_embeds, input_ids, self.config.video_token_id, note="video")
if pixel_values is not None:
image_features = self.get_image_features(pixel_values)
inputs_embeds = self.replace_multi_modal_embeddings(image_features, inputs_embeds, input_ids, self.config.image_token_id, note="image")
if audio_values is not None:
audio_features = self.get_audio_features(audio_values)
inputs_embeds = self.replace_multi_modal_embeddings(audio_features, inputs_embeds, input_ids, self.config.audio_token_id, note="audio")
outputs = self.llm_model(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs
)
return outputs
def prepare_inputs_for_generation(self,
input_ids,
past_key_values=None,
attention_mask=None,
use_cache=None,
pixel_values=None,
pixel_values_videos=None,
audio_values=None,
cache_position=None,
**kwargs):
model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
use_cache=use_cache,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
audio_values=audio_values,
**kwargs
)
        # After the prefill step the multimodal features are already baked into
        # the KV cache, so drop them for subsequent decoding steps.
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
            model_inputs["audio_values"] = None
return model_inputs
def _expand_inputs_for_generation(self,
expand_size: int = 1,
is_encoder_decoder: bool = False,
input_ids: Optional[torch.LongTensor] = None,
**model_kwargs,
):
"""Expands input tensors for generation when using beam search or sampling.
Args:
expand_size (int, optional): The size to expand the inputs by. Defaults to 1.
is_encoder_decoder (bool, optional): Whether the model is an encoder-decoder model. Defaults to False.
input_ids (Optional[torch.LongTensor], optional): The input token IDs. Defaults to None.
**model_kwargs: Additional model-specific keyword arguments.
Returns:
Tuple[torch.LongTensor, Dict[str, torch.Tensor]]: The expanded input_ids and model_kwargs.
"""
if input_ids is not None:
input_ids = input_ids.repeat_interleave(expand_size, dim=0)
# Expand attention mask if present
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat_interleave(expand_size, dim=0)
# Expand position IDs if present
if "position_ids" in model_kwargs:
model_kwargs["position_ids"] = model_kwargs["position_ids"].repeat_interleave(expand_size, dim=0)
# Expand pixel values for images if present
if "pixel_values" in model_kwargs and model_kwargs["pixel_values"] is not None:
model_kwargs["pixel_values"] = model_kwargs["pixel_values"].repeat_interleave(expand_size, dim=0)
# Expand pixel values for videos if present
if "pixel_values_videos" in model_kwargs and model_kwargs["pixel_values_videos"] is not None:
model_kwargs["pixel_values_videos"] = model_kwargs["pixel_values_videos"].repeat_interleave(expand_size, dim=0)
# Expand audio values if present
if "audio_values" in model_kwargs and model_kwargs["audio_values"] is not None:
model_kwargs["audio_values"] = model_kwargs["audio_values"].repeat_interleave(expand_size, dim=0)
# Expand cache position if present
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
model_kwargs["cache_position"] = model_kwargs["cache_position"].repeat_interleave(expand_size, dim=0)
return input_ids, model_kwargs
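

if __name__ == "__main__":
    # Illustrative sketch, not part of the model: a self-contained demo of the
    # masked_scatter pattern used by replace_multi_modal_embeddings, with dummy
    # tensors and a made-up placeholder token id (99).
    hidden_size = 4
    placeholder_id = 99
    input_ids = torch.tensor([[10, placeholder_id, placeholder_id, 11]])
    input_embeds = torch.zeros(1, 4, hidden_size)
    # Two "modality" embeddings, one per placeholder token.
    modal_embeds = torch.arange(2 * hidden_size, dtype=torch.float32).view(2, hidden_size)
    mask = (input_ids == placeholder_id).unsqueeze(-1).expand_as(input_embeds)
    merged = input_embeds.masked_scatter(mask, modal_embeds)
    print(merged)  # positions 1 and 2 of the sequence now hold the modality embeddings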