import logging
from dataclasses import dataclass
from typing import Optional

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.cache_utils import StaticCache
from transformers.generation import GenerationMixin
from transformers.generation.utils import GenerationConfig, GenerationMode
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.hubert.modeling_hubert import (
    HubertEncoder,
    HubertEncoderStableLayerNorm,
)
from transformers.utils import ModelOutput

from .configuration_avhubert import AVHubertConfig
from .configuration_resnet import ResEncoderConfig
from .decoder import AVHubertDecoder, AVHubertDecoderStableLayerNorm
from .modeling_resnet import ResEncoder

logger = logging.getLogger(__name__)

NEED_SETUP_CACHE_CLASSES_MAPPING = {
    "static": StaticCache,
}


@dataclass
class AVHubertOutput:
    """Output of [`AVHubertModel`], mirroring the fields of a standard encoder output."""

    last_hidden_state: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
    attentions: Optional[torch.Tensor] = None


class AudioFeatureExtractor(nn.Module):
    """Projects pre-computed audio features to the encoder embedding size."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, time, feature) -> (batch, feature, time) to match the video branch
        x = self.proj(x)
        return einops.rearrange(x, "b t f -> b f t")


class VideoFeatureExtractor(nn.Module):
    """Extracts per-frame visual features with a ResNet front end and projects them to the encoder embedding size."""

    def __init__(self, config: ResEncoderConfig, output_dim: int) -> None:
        super().__init__()
        self.resnet = ResEncoder(config=config)
        self.proj = nn.Linear(
            in_features=self.resnet.backend_out,
            out_features=output_dim,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, time, channel, height, width) -> (batch, channel, time, height, width)
        x = self.resnet(einops.rearrange(x, "b t c h w -> b c t h w"))
        # ResEncoder returns (batch, feature, time); project over the feature axis
        x = self.proj(einops.rearrange(x, "b f t -> b t f"))
        return einops.rearrange(x, "b t f -> b f t")


class AVHubertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AVHubertConfig
    base_model_prefix = "avhubert"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
                else:
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
            elif hasattr(module, "parametrizations"):
                # Weight-normalized conv: initialize the underlying parametrization tensors,
                # since `module.weight` is recomputed from them on every access
                nn.init.kaiming_normal_(module.parametrizations.weight.original0.data)
                nn.init.kaiming_normal_(module.parametrizations.weight.original1.data)
            else:
                nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)) and module.bias is not None:
            module.bias.data.zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken from
            # https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths
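
    # Worked example (hypothetical config values, for illustration only): with
    # conv_kernel=(10, 3) and conv_stride=(5, 2), an input of 400 samples yields
    # floor((400 - 10) / 5) + 1 = 79 frames after the first layer and
    # floor((79 - 3) / 2) + 1 = 39 frames after the second.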


class AVHubertModel(AVHubertPreTrainedModel):
    def __init__(self, config: AVHubertConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        self.feat2tar_ratio = config.label_rate / config.sample_rate

        # Modality-specific feature extractors
        resnet_config = ResEncoderConfig(relu_type=config.resnet_relu_type)
        self.feature_extractor_audio = AudioFeatureExtractor(
            input_dim=config.audio_feat_dim,
            output_dim=config.encoder_embed_dim,
        )
        self.feature_extractor_video = VideoFeatureExtractor(config=resnet_config, output_dim=config.encoder_embed_dim)

        self.encoder_embed_dim = config.encoder_embed_dim
        if config.modality_fuse == "concat":
            embed = config.encoder_embed_dim * 2
        elif config.modality_fuse == "add":
            embed = config.encoder_embed_dim
        else:
            raise ValueError(f"Unknown `modality_fuse`: {config.modality_fuse}. Expected 'concat' or 'add'.")
        self.post_extract_proj = (
            nn.Linear(embed, config.encoder_embed_dim) if embed != config.encoder_embed_dim else None
        )

        self.dropout_input = nn.Dropout(config.dropout_input)

        # Transformer encoder reused from HuBERT
        transformer_config = config.encoder_config
        if transformer_config.do_stable_layer_norm:
            self.encoder = HubertEncoderStableLayerNorm(config=transformer_config)
        else:
            self.encoder = HubertEncoder(config=transformer_config)
        self.layer_norm = nn.LayerNorm(embed)

    def forward_mask(self, features: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Downsamples a per-sample padding mask to the per-frame resolution of `features`."""
        extra = attention_mask.size(1) % features.size(1)
        if extra > 0:
            attention_mask = attention_mask[:, :-extra]
        attention_mask = attention_mask.view(attention_mask.size(0), features.size(1), -1)
        # A frame counts as padded only if every position it covers is padded
        attention_mask = attention_mask.all(-1)
        return attention_mask
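
    # Example (illustrative shapes): for features of shape (B, 100, F) and a
    # sample-level mask of length 403, the 3 trailing positions are dropped
    # (403 % 100 == 3), the mask is viewed as (B, 100, 4), and `.all(-1)` marks a
    # frame as padded only when all 4 underlying positions are padded.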

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        **kwargs,
    ) -> ModelOutput:
        # A missing modality is replaced by zero features so the fusion logic below is uniform
        if input_values is not None and pixel_values is None:
            features_audio = self.feature_extractor_audio(input_values)
            features_video = torch.zeros_like(features_audio)
        elif input_values is None and pixel_values is not None:
            features_video = self.feature_extractor_video(pixel_values)
            features_audio = torch.zeros_like(features_video)
        elif input_values is not None and pixel_values is not None:
            features_audio = self.feature_extractor_audio(input_values)
            features_video = self.feature_extractor_video(pixel_values)
        else:
            raise ValueError("Either `input_values` or `pixel_values` must be passed")

        # Fuse modalities along the feature axis (both extractors return (batch, feature, time))
        if self.config.modality_fuse == "concat":
            features = torch.cat([features_audio, features_video], dim=1)
        elif self.config.modality_fuse == "add":
            features = features_audio + features_video

        features = features.transpose(1, 2)  # (batch, time, feature)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_mask(features, padding_mask)
        else:
            padding_mask = torch.zeros(features.size()[:2], dtype=torch.bool, device=features.device)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)

        # The HuBERT encoder expects an attention mask with True at positions to attend to
        encoder_out = self.encoder(
            hidden_states=features,
            attention_mask=~padding_mask.bool(),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        return AVHubertOutput(
            last_hidden_state=encoder_out.last_hidden_state,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
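

# Minimal usage sketch (tensor shapes and the 88x88 grayscale crop size are
# illustrative assumptions, not enforced by this module):
#
#   model = AVHubertModel(AVHubertConfig())
#   audio = torch.randn(1, 250, model.config.audio_feat_dim)  # (batch, frames, audio_feat_dim)
#   video = torch.randn(1, 250, 1, 88, 88)                    # (batch, frames, C, H, W)
#   out = model(input_values=audio, pixel_values=video)
#   out.last_hidden_state                                     # (batch, frames, encoder_embed_dim)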


class AVHubertForConditionalGeneration(AVHubertPreTrainedModel, GenerationMixin):
    def __init__(
        self,
        config: AVHubertConfig,
        **kwargs,
    ) -> None:
        super().__init__(config=config, **kwargs)
        self.config = config

        self.avhubert = AVHubertModel(config=config)
        if config.freeze_base_model:
            self.freeze_base_model()
        if config.freeze_feature_encoder:
            self.freeze_feature_encoder()

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: "
                "`AVHubertForConditionalGeneration.from_pretrained(..., vocab_size=vocab_size)` "
                "or define `vocab_size` in your model's configuration."
            )

        self.embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim, padding_idx=config.pad_token_id)
        transformer_config = config.decoder_config
        if transformer_config.do_stable_layer_norm:
            self.decoder = AVHubertDecoderStableLayerNorm(config=transformer_config)
        else:
            self.decoder = AVHubertDecoder(config=transformer_config)

        self.lm_head = nn.Linear(config.decoder_embed_dim, config.vocab_size, bias=False)
        if config.share_decoder_input_output_embed:
            # Tie the output projection to the input embedding matrix,
            # cf fairseq's `share_decoder_input_output_embed`
            self.lm_head.weight = self.embed_tokens.weight
        else:
            nn.init.normal_(self.lm_head.weight, mean=0, std=config.decoder_embed_dim**-0.5)

        self.post_init()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the audio and video feature extractors so
        that their parameters will not be updated during training.
        """
        for param in self.avhubert.feature_extractor_audio.parameters():
            param.requires_grad = False
        for param in self.avhubert.feature_extractor_video.parameters():
            param.requires_grad = False

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will
        not be updated during training. Only the decoder and the language model head will be updated.
        """
        for param in self.avhubert.parameters():
            param.requires_grad = False

    def get_encoder(self):
        return self.avhubert

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> ModelOutput:
        encoder_outs = self.avhubert(
            input_values=input_values,
            pixel_values=pixel_values,
            padding_mask=padding_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Cross-attention mask: True at positions the decoder may attend to
        encoder_attention_mask = ~padding_mask.bool() if padding_mask is not None else None

        embed_tokens = self.embed_tokens(decoder_input_ids)
        hidden_states = self.decoder(
            inputs_embeds=embed_tokens,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outs.last_hidden_state,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if self.config.share_decoder_input_output_embed:
            # Equivalent to `self.lm_head` once the weights are tied; kept explicit for clarity
            logits = F.linear(hidden_states.last_hidden_state, weight=self.embed_tokens.weight)
        else:
            logits = self.lm_head(hidden_states.last_hidden_state)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
            loss = loss_fn(logits.reshape(-1, self.config.vocab_size), labels.reshape(-1))

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=None,
            decoder_hidden_states=hidden_states.hidden_states,
            decoder_attentions=hidden_states.attentions,
            cross_attentions=None,
            encoder_last_hidden_state=encoder_outs.last_hidden_state,
            encoder_hidden_states=encoder_outs.hidden_states,
            encoder_attentions=encoder_outs.attentions,
        )

    def _get_generation_mode(
        self,
        generation_config: GenerationConfig,
        assistant_model: PreTrainedModel | None,
    ) -> GenerationMode:
        """
        Returns the generation mode triggered by a [`GenerationConfig`] instance.
        """
        if generation_config.constraints is not None or generation_config.force_words_ids is not None:
            generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
        elif generation_config.num_beams == 1:
            if generation_config.do_sample is False:
                if (
                    generation_config.top_k is not None
                    and generation_config.top_k > 1
                    and generation_config.penalty_alpha is not None
                    and generation_config.penalty_alpha > 0
                ):
                    generation_mode = GenerationMode.CONTRASTIVE_SEARCH
                else:
                    generation_mode = GenerationMode.GREEDY_SEARCH
            else:
                generation_mode = GenerationMode.SAMPLE
        else:
            if generation_config.num_beam_groups > 1:
                generation_mode = GenerationMode.GROUP_BEAM_SEARCH
            elif generation_config.do_sample is True:
                generation_mode = GenerationMode.BEAM_SAMPLE
            else:
                generation_mode = GenerationMode.BEAM_SEARCH

        # Assisted generation may extend some generation modes
        if assistant_model is not None or generation_config.prompt_lookup_num_tokens is not None:
            # GenerationMode is a string enum, so comparison against the raw values works
            if generation_mode in ("greedy_search", "sample"):
                generation_mode = GenerationMode.ASSISTED_GENERATION
            else:
                raise ValueError(
                    "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
                    "is only supported with Greedy Search and Sample."
                )
        return generation_mode
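
    # For example, GenerationConfig(num_beams=5, do_sample=False) maps to
    # GenerationMode.BEAM_SEARCH above, while the defaults (num_beams=1,
    # do_sample=False, no contrastive-search settings) map to
    # GenerationMode.GREEDY_SEARCH.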

    def prepare_inputs_for_generation(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_values: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # During generation the decoder token ids arrive as `input_ids`; mirror them
        # into the decoder arguments expected by `forward`
        if decoder_input_ids is None:
            decoder_input_ids = input_ids
            decoder_attention_mask = torch.ones_like(input_ids)
        return {
            "input_values": input_values,
            "pixel_values": pixel_values,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "padding_mask": padding_mask,
        }
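

# Hedged generation sketch (checkpoint path and tensor shapes are illustrative
# assumptions, not part of this module):
#
#   model = AVHubertForConditionalGeneration.from_pretrained("path/to/checkpoint")
#   generated_ids = model.generate(
#       input_values=audio,         # (batch, frames, audio_feat_dim)
#       pixel_values=video,         # (batch, frames, C, H, W)
#       padding_mask=padding_mask,  # True at padded positions
#       num_beams=5,
#   )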