|
|
import logging |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
|
|
|
import einops |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import PreTrainedModel |
|
|
from transformers.cache_utils import StaticCache |
|
|
from transformers.generation import GenerationMixin |
|
|
from transformers.generation.utils import GenerationConfig, GenerationMode |
|
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
|
|
from transformers.modeling_outputs import Seq2SeqLMOutput |
|
|
from transformers.models.hubert.modeling_hubert import ( |
|
|
HubertEncoder, |
|
|
HubertEncoderStableLayerNorm, |
|
|
) |
|
|
from transformers.utils import ModelOutput |
|
|
|
|
|
from .configuration_avhubert import AVHubertConfig |
|
|
from .configuration_resnet import ResEncoderConfig |
|
|
from .decoder import AVHubertDecoder, AVHubertDecoderStableLayerNorm |
|
|
from .modeling_resnet import ResEncoder |
|
|
|
|
|
# Module-level logger; configuration is left to the application.
logger = logging.getLogger(__name__)


# Maps `generation_config.cache_implementation` names to cache classes that must
# be instantiated ahead of generation (mirrors transformers' own mapping).
NEED_SETUP_CACHE_CLASSES_MAPPING = {
    "static": StaticCache,
}
|
|
|
|
|
|
|
|
@dataclass
class AVHubertOutput:
    """Output container for `AVHubertModel.forward`.

    Field names mirror `transformers.modeling_outputs.BaseModelOutput` so the
    encoder output can be consumed like a standard HF encoder output.
    """

    # Final encoder hidden states; presumably (batch, time, hidden) — TODO confirm.
    last_hidden_state: Optional[torch.Tensor] = None
    # Per-layer hidden states when `output_hidden_states=True`
    # (likely a tuple of tensors despite the annotation — TODO confirm).
    hidden_states: Optional[torch.Tensor] = None
    # Per-layer attention weights when `output_attentions=True`
    # (likely a tuple of tensors despite the annotation — TODO confirm).
    attentions: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
class AudioFeatureExtractor(nn.Module):
    """Project per-frame audio features to the encoder embedding size.

    Takes ``(batch, time, input_dim)`` and returns ``(batch, output_dim, time)``,
    i.e. channel-first, matching the layout produced by the video extractor.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` over its last dim and move features to dim 1."""
        # `einops.rearrange(x, "b t f -> b f t")` is exactly a transpose of the
        # last two dims; use the plain torch op instead of the extra dependency.
        return self.proj(x).transpose(1, 2)
|
|
|
|
|
|
|
|
class VideoFeatureExtractor(nn.Module):
    """Extract lip-video features with a ResNet and project them to the encoder size.

    Takes ``(batch, time, channels, height, width)`` and returns
    ``(batch, output_dim, time)``, channel-first like the audio extractor.
    """

    def __init__(self, config: ResEncoderConfig, output_dim: int) -> None:
        super().__init__()
        self.resnet = ResEncoder(config=config)
        self.proj = nn.Linear(
            in_features=self.resnet.backend_out,
            out_features=output_dim,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the ResNet backbone and project its features."""
        # ResNet expects channel-first video: (b, t, c, h, w) -> (b, c, t, h, w).
        # Plain permute/transpose replace the previous einops.rearrange calls.
        feats = self.resnet(x.permute(0, 2, 1, 3, 4))
        # ResNet output is treated as (b, f, t); project over the feature dim...
        feats = self.proj(feats.transpose(1, 2))
        # ...and return to channel-first (b, f, t).
        return feats.transpose(1, 2)
|
|
|
|
|
|
|
|
class AVHubertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AVHubertConfig
    base_model_prefix = "avhubert"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Normal init with the config-supplied std, as in other HF models.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            # Identity-transform init for normalization layers.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # Under ZeRO-3 parameters are partitioned across ranks; gather
                # them before initializing in place.
                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                    # Legacy torch weight_norm exposes weight_v / weight_g.
                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
                else:
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
            else:
                if hasattr(module, "parametrizations"):
                    # New-style torch.nn.utils.parametrizations.weight_norm stores
                    # the underlying tensors as original0 / original1.
                    nn.init.kaiming_normal_(module.parametrizations.weight.original0.data)
                    nn.init.kaiming_normal_(module.parametrizations.weight.original1.data)
                # NOTE(review): this line also runs for parametrized convs (there is
                # no `else`); writing to the computed `module.weight` is likely a
                # no-op in that case — confirm whether an `else` was intended.
                nn.init.kaiming_normal_(module.weight.data)

        # Zero biases for all linear/conv modules, regardless of the branch above.
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)) and module.bias is not None:
            module.bias.data.zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor | int):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D conv output-length formula (no padding/dilation assumed):
            # floor((L - kernel) / stride) + 1.
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        # Fold the formula over every conv layer declared in the config.
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths
|
|
|
|
|
|
|
|
class AVHubertModel(AVHubertPreTrainedModel):
    """Audio-visual HuBERT encoder.

    Extracts audio and video features, fuses them (concat or add), normalizes,
    optionally projects to the encoder width, and runs a HuBERT transformer
    encoder.
    """

    def __init__(self, config: AVHubertConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.config = config
        # Label frames per input sample (label rate over sample rate).
        self.feat2tar_ratio = config.label_rate / config.sample_rate

        resnet_config = ResEncoderConfig(relu_type=config.resnet_relu_type)
        self.feature_extractor_audio = AudioFeatureExtractor(
            input_dim=config.audio_feat_dim,
            output_dim=config.encoder_embed_dim,
        )
        self.feature_extractor_video = VideoFeatureExtractor(config=resnet_config, output_dim=config.encoder_embed_dim)

        self.encoder_embed_dim = config.encoder_embed_dim
        if config.modality_fuse == "concat":
            embed = config.encoder_embed_dim * 2
        elif config.modality_fuse == "add":
            embed = config.encoder_embed_dim
        else:
            # Fail fast with a clear message instead of hitting a NameError on
            # `embed` below when the config carries an unsupported value.
            raise ValueError(
                f"Unsupported `modality_fuse`: {config.modality_fuse!r}; expected 'concat' or 'add'."
            )
        # Only needed when the fused width differs from the encoder width.
        self.post_extract_proj = (
            nn.Linear(embed, config.encoder_embed_dim) if embed != config.encoder_embed_dim else None
        )

        self.dropout_input = nn.Dropout(config.dropout_input)

        transformer_config = config.encoder_config
        if transformer_config.do_stable_layer_norm:
            self.encoder = HubertEncoderStableLayerNorm(config=transformer_config)
        else:
            self.encoder = HubertEncoder(config=transformer_config)
        self.layer_norm = nn.LayerNorm(embed)

    def forward_mask(self, features: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Downsample a per-sample mask to the feature frame rate.

        Trailing samples that do not fill a whole frame are dropped; a frame is
        flagged only when *all* of its samples are flagged (assumes True marks
        padded positions — TODO confirm against callers).
        """
        extra = attention_mask.size(1) % features.size(1)
        if extra > 0:
            attention_mask = attention_mask[:, :-extra]
        attention_mask = attention_mask.view(attention_mask.size(0), features.size(1), -1)
        attention_mask = attention_mask.all(-1)
        return attention_mask

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        **kwargs,
    ) -> ModelOutput:
        """Encode audio and/or video.

        Args:
            input_values: audio features; when absent, zeros stand in for audio.
            pixel_values: video frames; when absent, zeros stand in for video.
            padding_mask: per-sample mask (True = padded, presumably); downsampled
                to frame rate via `forward_mask`.
            output_attentions / output_hidden_states: forwarded to the encoder.

        Returns:
            `AVHubertOutput` with the encoder's last hidden state and optional
            per-layer states/attentions.

        Raises:
            ValueError: when neither modality is provided, or `modality_fuse`
                is not 'concat'/'add'.
        """
        # Missing modalities are replaced by zero tensors of the matching shape,
        # so the fusion step always sees both streams.
        if input_values is not None and pixel_values is None:
            features_audio = self.feature_extractor_audio(input_values)
            features_video = torch.zeros_like(features_audio)
        elif input_values is None and pixel_values is not None:
            features_video = self.feature_extractor_video(pixel_values)
            features_audio = torch.zeros_like(features_video)
        elif input_values is not None and pixel_values is not None:
            features_audio = self.feature_extractor_audio(input_values)
            features_video = self.feature_extractor_video(pixel_values)
        else:
            raise ValueError("Either `input_values` or `pixel_values` must be passed")

        # Extractors emit (batch, feat, time); concat fuses along the feature dim.
        if self.config.modality_fuse == "concat":
            features = torch.cat([features_audio, features_video], dim=1)
        elif self.config.modality_fuse == "add":
            features = features_audio + features_video
        else:
            # Defensive: __init__ already validates, but avoid an undefined
            # `features` if the config was mutated afterwards.
            raise ValueError(
                f"Unsupported `modality_fuse`: {self.config.modality_fuse!r}; expected 'concat' or 'add'."
            )

        # (batch, feat, time) -> (batch, time, feat) for LayerNorm / encoder.
        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        if padding_mask is not None:
            padding_mask = self.forward_mask(features, padding_mask)
        else:
            # No padding information: treat every frame as valid.
            padding_mask = torch.zeros(features.size()[:2], dtype=torch.bool, device=features.device)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)

        # HF encoders expect attention_mask with True = attend, hence the inversion.
        encoder_out = self.encoder(
            hidden_states=features,
            attention_mask=~padding_mask.bool(),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        return AVHubertOutput(
            last_hidden_state=encoder_out.last_hidden_state,
            hidden_states=encoder_out.hidden_states,
            attentions=encoder_out.attentions,
        )
|
|
|
|
|
|
|
|
class AVHubertForConditionalGeneration(AVHubertPreTrainedModel, GenerationMixin):
    """AV-HuBERT encoder plus a transformer decoder and LM head for seq2seq generation."""

    def __init__(
        self,
        config: AVHubertConfig,
        **kwargs,
    ) -> None:
        super().__init__(config=config, **kwargs)
        self.config = config

        self.avhubert = AVHubertModel(config=config)
        if config.freeze_base_model:
            self.freeze_base_model()
        if config.freeze_feature_encoder:
            self.freeze_feature_encoder()

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `AVHubertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )

        self.embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim, padding_idx=config.pad_token_id)
        transformer_config = config.decoder_config
        if transformer_config.do_stable_layer_norm:
            self.decoder = AVHubertDecoderStableLayerNorm(config=transformer_config)
        else:
            self.decoder = AVHubertDecoder(config=transformer_config)

        self.lm_head = nn.Linear(config.decoder_embed_dim, config.vocab_size, bias=False)
        if config.share_decoder_input_output_embed:
            # Weight tying: the output projection reuses the input embedding matrix.
            self.lm_head.weight = self.embed_tokens.weight
        else:
            nn.init.normal_(self.lm_head.weight, mean=0, std=config.decoder_embed_dim**-0.5)

        self.post_init()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        for param in self.avhubert.feature_extractor_audio.parameters():
            param.requires_grad = False
        for param in self.avhubert.feature_extractor_video.parameters():
            param.requires_grad = False

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.avhubert.parameters():
            param.requires_grad = False

    def get_encoder(self):
        """Return the encoder, as required by `GenerationMixin` for seq2seq models."""
        return self.avhubert

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> ModelOutput:
        """Encode audio/video, decode with teacher forcing, and optionally compute the loss.

        Args:
            input_values / pixel_values / padding_mask: forwarded to `AVHubertModel`.
            decoder_input_ids: token ids fed to the decoder via `embed_tokens`.
            decoder_attention_mask: self-attention mask for the decoder.
            labels: target ids; when given, a label-smoothed cross-entropy loss is returned.

        Returns:
            `Seq2SeqLMOutput` with logits and, when `labels` is given, the loss.
        """
        encoder_outs = self.avhubert(
            input_values=input_values,
            pixel_values=pixel_values,
            padding_mask=padding_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Fix: `padding_mask` defaults to None, so guard before inverting it —
        # `~None.bool()` raised a TypeError.
        # NOTE(review): this mask is at the caller's (sample-level) resolution while
        # the encoder states are frame-level; confirm callers pass a frame-level mask.
        encoder_attention_mask = ~padding_mask.bool() if padding_mask is not None else None

        embed_tokens = self.embed_tokens(decoder_input_ids)
        hidden_states = self.decoder(
            inputs_embeds=embed_tokens,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outs.last_hidden_state,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if self.config.share_decoder_input_output_embed:
            # Tied projection: reuse the embedding matrix directly.
            logits = F.linear(hidden_states.last_hidden_state, weight=self.embed_tokens.weight)
        else:
            logits = self.lm_head(hidden_states.last_hidden_state)

        loss = None
        if labels is not None:
            # Label smoothing of 0.1 is hard-coded here (not read from config).
            loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
            loss = loss_fn(logits.view(-1, self.config.vocab_size), labels.reshape(-1))

        return Seq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=None,
            decoder_hidden_states=hidden_states.hidden_states,
            decoder_attentions=hidden_states.attentions,
            cross_attentions=None,
            encoder_last_hidden_state=encoder_outs.last_hidden_state,
            encoder_hidden_states=encoder_outs.hidden_states,
            encoder_attentions=encoder_outs.attentions,
        )

    def _get_generation_mode(
        self,
        generation_config: GenerationConfig,
        assistant_model: PreTrainedModel | None,
    ) -> GenerationMode:
        """
        Returns the generation mode triggered by a [`GenerationConfig`] instance.
        """
        if generation_config.constraints is not None or generation_config.force_words_ids is not None:
            generation_mode = GenerationMode.CONSTRAINED_BEAM_SEARCH
        elif generation_config.num_beams == 1:
            if generation_config.do_sample is False:
                # Contrastive search requires both top_k > 1 and penalty_alpha > 0.
                if (
                    generation_config.top_k is not None
                    and generation_config.top_k > 1
                    and generation_config.penalty_alpha is not None
                    and generation_config.penalty_alpha > 0
                ):
                    generation_mode = GenerationMode.CONTRASTIVE_SEARCH
                else:
                    generation_mode = GenerationMode.GREEDY_SEARCH
            else:
                generation_mode = GenerationMode.SAMPLE
        else:
            if generation_config.num_beam_groups > 1:
                generation_mode = GenerationMode.GROUP_BEAM_SEARCH
            elif generation_config.do_sample is True:
                generation_mode = GenerationMode.BEAM_SAMPLE
            else:
                generation_mode = GenerationMode.BEAM_SEARCH

        # Assisted generation only wraps greedy search and sampling.
        # (GenerationMode is a str-enum, so comparing against plain strings works.)
        if assistant_model is not None or generation_config.prompt_lookup_num_tokens is not None:
            if generation_mode in ("greedy_search", "sample"):
                generation_mode = GenerationMode.ASSISTED_GENERATION
            else:
                raise ValueError(
                    "You've set `assistant_model`, which triggers assisted generate. Currently, assisted generate "
                    "is only supported with Greedy Search and Sample."
                )
        return generation_mode

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor = None,
        input_values: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        decoder_attention_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Assemble the kwargs `generate()` feeds into `forward` at each step."""
        if decoder_input_ids is None:
            # Fall back to `input_ids` as the decoder prompt, attending everywhere.
            decoder_input_ids = input_ids
            decoder_attention_mask = torch.ones_like(input_ids)
        return {
            "input_values": input_values,
            "pixel_values": pixel_values,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "padding_mask": padding_mask,
        }
|
|
|