diff --git a/docs/transformers/build/lib/transformers/models/visual_bert/configuration_visual_bert.py b/docs/transformers/build/lib/transformers/models/visual_bert/configuration_visual_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..a866227d3470b9b63f5ca7f2afad458b90cbd11d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/visual_bert/configuration_visual_bert.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VisualBERT model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class VisualBertConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VisualBertModel`]. It is used to instantiate a + VisualBERT model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the VisualBERT + [uclanlp/visualbert-vqa-coco-pre](https://huggingface.co/uclanlp/visualbert-vqa-coco-pre) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the VisualBERT model. Defines the number of different tokens that can be represented by + the `input_ids` passed when calling [`VisualBertModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + visual_embedding_dim (`int`, *optional*, defaults to 512): + Dimensionality of the visual embeddings to be passed to the model. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`VisualBertModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + bypass_transformer (`bool`, *optional*, defaults to `False`): + Whether or not the model should bypass the transformer for the visual embeddings. If set to `True`, the + model directly concatenates the visual embeddings from [`VisualBertEmbeddings`] with the text output from + the transformer, and then passes the result to a self-attention layer. + special_visual_initialize (`bool`, *optional*, defaults to `True`): + Whether or not the visual token type and position type embedding weights should be initialized the same as + the textual token type and position type embeddings. When set to `True`, the weights of the textual token + type and position type embeddings are copied to the respective visual embedding layers. + + + Example: + + ```python + >>> from transformers import VisualBertConfig, VisualBertModel + + >>> # Initializing a VisualBERT visualbert-vqa-coco-pre style configuration + >>> configuration = VisualBertConfig.from_pretrained("uclanlp/visualbert-vqa-coco-pre") + + >>> # Initializing a model (with random weights) from the visualbert-vqa-coco-pre style configuration + >>> model = VisualBertModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "visual_bert" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + visual_embedding_dim=512, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + bypass_transformer=False, + special_visual_initialize=True, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.visual_embedding_dim = visual_embedding_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.type_vocab_size = type_vocab_size + self.layer_norm_eps = layer_norm_eps + self.bypass_transformer = bypass_transformer + self.special_visual_initialize = special_visual_initialize + + +__all__ = ["VisualBertConfig"] diff --git a/docs/transformers/build/lib/transformers/models/visual_bert/modeling_visual_bert.py b/docs/transformers/build/lib/transformers/models/visual_bert/modeling_visual_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..2d81258a89c4118412f43d88fb8cf49ad67a9de6 --- /dev/null +++ 
b/docs/transformers/build/lib/transformers/models/visual_bert/modeling_visual_bert.py @@ -0,0 +1,1597 @@ +# coding=utf-8 +# Copyright 2021 The UCLA NLP Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch VisualBERT model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MultipleChoiceModelOutput, + SequenceClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_visual_bert import VisualBertConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "VisualBertConfig" +_CHECKPOINT_FOR_DOC = "uclanlp/visualbert-vqa-coco-pre" + + +class VisualBertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings and visual embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + # For Visual Features + # Token type and position embedding for image features + self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + if config.special_visual_initialize: + self.visual_token_type_embeddings.weight.data = nn.Parameter( + self.token_type_embeddings.weight.data.clone(), requires_grad=True + ) + self.visual_position_embeddings.weight.data = nn.Parameter( + self.position_embeddings.weight.data.clone(), requires_grad=True + ) + + self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size) + + def forward( + self, + input_ids=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + visual_embeds=None, + visual_token_type_ids=None, + 
image_text_alignment=None, + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + + # Absolute Position Embeddings + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + if visual_embeds is not None: + if visual_token_type_ids is None: + visual_token_type_ids = torch.ones( + visual_embeds.size()[:-1], dtype=torch.long, device=self.position_ids.device + ) + + visual_embeds = self.visual_projection(visual_embeds) + visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids) + + if image_text_alignment is not None: + # image_text_alignment = Batch x image_length x alignment_number. + # Each element denotes the position of the word corresponding to the image feature. -1 is the padding value. + + dtype = token_type_embeddings.dtype + image_text_alignment_mask = (image_text_alignment != -1).long() + # Get rid of the -1. + image_text_alignment = image_text_alignment_mask * image_text_alignment + + # Batch x image_length x alignment length x dim + visual_position_embeddings = self.position_embeddings(image_text_alignment) + visual_position_embeddings *= image_text_alignment_mask.to(dtype=dtype).unsqueeze(-1) + visual_position_embeddings = visual_position_embeddings.sum(2) + + # We want to average along the alignment_number dimension. + image_text_alignment_mask = image_text_alignment_mask.to(dtype=dtype).sum(2) + + if (image_text_alignment_mask == 0).sum() != 0: + image_text_alignment_mask[image_text_alignment_mask == 0] = 1 # Avoid divide by zero error + logger.warning( + "Found 0 values in `image_text_alignment_mask`. Setting them to 1 to avoid divide-by-zero" + " error." + ) + visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1) + + visual_position_ids = torch.zeros( + *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device + ) + + # When fine-tuning the detector, the image_text_alignment is sometimes padded too long.
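+                # In that case the alignment-based position embeddings are truncated below to
+                # `visual_embeds.size(1)`; a result *shorter* than the visual sequence cannot be
+                # recovered from, so it raises instead.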
+ if visual_position_embeddings.size(1) != visual_embeds.size(1): + if visual_position_embeddings.size(1) < visual_embeds.size(1): + raise ValueError( + f"Visual position embeddings length: {visual_position_embeddings.size(1)} " + f"should be the same as `visual_embeds` length: {visual_embeds.size(1)}" + ) + visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :] + + visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings( + visual_position_ids + ) + else: + visual_position_ids = torch.zeros( + *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device + ) + visual_position_embeddings = self.visual_position_embeddings(visual_position_ids) + + visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings + + embeddings = torch.cat((embeddings, visual_embeddings), dim=1) + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class VisualBertSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask (precomputed for all layers in the VisualBertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->VisualBert +class VisualBertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class VisualBertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = VisualBertSelfAttention(config) + self.output = VisualBertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->VisualBert +class VisualBertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->VisualBert +class VisualBertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def 
forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class VisualBertLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = VisualBertAttention(config) + self.intermediate = VisualBertIntermediate(config) + self.output = VisualBertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + ): + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class VisualBertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + all_hidden_states, + all_self_attentions, + ] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->VisualBert +class VisualBertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
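+        # With VisualBERT, visual embeddings are appended after the text tokens, so index 0
+        # is still the text [CLS] token even when visual inputs are present.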
+ first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->VisualBert +class VisualBertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->VisualBert +class VisualBertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = VisualBertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def _tie_weights(self): + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->VisualBert +class VisualBertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = VisualBertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class VisualBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = VisualBertConfig + base_model_prefix = "visual_bert" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, VisualBertLMPredictionHead): + module.bias.data.zero_() + + +@dataclass +class VisualBertForPreTrainingOutput(ModelOutput): + """ + Output type of [`VisualBertForPreTraining`]. 
+ + Args: + loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the masked language modeling loss and the sentence-image prediction + (classification) loss. + prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`): + Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation + before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + prediction_logits: Optional[torch.FloatTensor] = None + seq_relationship_logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +VISUAL_BERT_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage + and behavior. + + Parameters: + config ([`VisualBertConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +VISUAL_BERT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token.
+ + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + visual_embeds (`torch.FloatTensor` of shape `(batch_size, visual_seq_length, visual_embedding_dim)`, *optional*): + The embedded representation of the visual inputs, generally derived using an object detector. + + visual_attention_mask (`torch.FloatTensor` of shape `(batch_size, visual_seq_length)`, *optional*): + Mask to avoid performing attention on visual embeddings. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + visual_token_type_ids (`torch.LongTensor` of shape `(batch_size, visual_seq_length)`, *optional*): + Segment token indices to indicate different portions of the visual embeds. + + [What are token type IDs?](../glossary#token-type-ids) The authors of VisualBERT set the + *visual_token_type_ids* to *1* for all tokens. + + image_text_alignment (`torch.LongTensor` of shape `(batch_size, visual_seq_length, alignment_number)`, *optional*): + Image-text alignment used to decide the position IDs of the visual embeddings. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare VisualBert Model transformer outputting raw hidden-states without any specific head on top.", + VISUAL_BERT_START_DOCSTRING, +) +class VisualBertModel(VisualBertPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) following the architecture described in [Attention is + all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
+ """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = VisualBertEmbeddings(config) + self.encoder = VisualBertEncoder(config) + + self.pooler = VisualBertPooler(config) if add_pooling_layer else None + + self.bypass_transformer = config.bypass_transformer + + if self.bypass_transformer: + self.additional_layer = VisualBertLayer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + visual_embeds: Optional[torch.FloatTensor] = None, + visual_attention_mask: Optional[torch.LongTensor] = None, + visual_token_type_ids: Optional[torch.LongTensor] = None, + image_text_alignment: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]: + r""" + + Returns: + + Example: + + ```python + # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image. 
+ from transformers import AutoTokenizer, VisualBertModel + import torch + + tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + model = VisualBertModel.from_pretrained("uclanlp/visualbert-vqa-coco-pre") + + inputs = tokenizer("The capital of France is Paris.", return_tensors="pt") + visual_embeds = get_visual_embeddings(image).unsqueeze(0) + visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) + visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float) + + inputs.update( + { + "visual_embeds": visual_embeds, + "visual_token_type_ids": visual_token_type_ids, + "visual_attention_mask": visual_attention_mask, + } + ) + + outputs = model(**inputs) + + last_hidden_states = outputs.last_hidden_state + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if visual_embeds is not None: + visual_input_shape = visual_embeds.size()[:-1] + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + + if visual_embeds is not None and visual_attention_mask is None: + visual_attention_mask = torch.ones(visual_input_shape, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
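+        # Below, the text and visual attention masks are concatenated so that self-attention
+        # runs over the single combined (text + visual) sequence built by the embeddings.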
+ if visual_embeds is not None: + combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1) + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + combined_attention_mask, (batch_size, input_shape + visual_input_shape) + ) + + else: + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( + attention_mask, (batch_size, input_shape) + ) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + visual_embeds=visual_embeds, + visual_token_type_ids=visual_token_type_ids, + image_text_alignment=image_text_alignment, + ) + + if self.bypass_transformer and visual_embeds is not None: + text_length = input_ids.size(1) + text_embedding_output = embedding_output[:, :text_length, :] + visual_embedding_output = embedding_output[:, text_length:, :] + + text_extended_attention_mask = extended_attention_mask[:, :, text_length, :text_length] + + encoded_outputs = self.encoder( + text_embedding_output, + attention_mask=text_extended_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoded_outputs[0] + concatenated_input = torch.cat((sequence_output, visual_embedding_output), dim=1) + sequence_output = self.additional_layer(concatenated_input, extended_attention_mask) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + else: + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """ + VisualBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a + `sentence-image prediction (classification)` head. 
+ """, + VISUAL_BERT_START_DOCSTRING, +) +class VisualBertForPreTraining(VisualBertPreTrainedModel): + _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.visual_bert = VisualBertModel(config) + self.cls = VisualBertPreTrainingHeads(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias + + @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + visual_embeds: Optional[torch.FloatTensor] = None, + visual_attention_mask: Optional[torch.LongTensor] = None, + visual_token_type_ids: Optional[torch.LongTensor] = None, + image_text_alignment: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + sentence_image_labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.Tensor], VisualBertForPreTrainingOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, total_sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + sentence_image_labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sentence-image prediction (classification) loss. Input should be a sequence pair + (see `input_ids` docstring) Indices should be in `[0, 1]`: + + - 0 indicates sequence B is a matching pair of sequence A for the given image, + - 1 indicates sequence B is a random sequence w.r.t A for the given image. + + Returns: + + Example: + + ```python + # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch. 
+ from transformers import AutoTokenizer, VisualBertForPreTraining + + tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + model = VisualBertForPreTraining.from_pretrained("uclanlp/visualbert-vqa-coco-pre") + + inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt") + visual_embeds = get_visual_embeddings(image).unsqueeze(0) + visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) + visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float) + + inputs.update( + { + "visual_embeds": visual_embeds, + "visual_token_type_ids": visual_token_type_ids, + "visual_attention_mask": visual_attention_mask, + } + ) + max_length = inputs["input_ids"].shape[-1] + visual_embeds.shape[-2] + labels = tokenizer( + "The capital of France is Paris.", return_tensors="pt", padding="max_length", max_length=max_length + )["input_ids"] + sentence_image_labels = torch.tensor(1).unsqueeze(0) # Batch_size + + + outputs = model(**inputs, labels=labels, sentence_image_labels=sentence_image_labels) + loss = outputs.loss + prediction_logits = outputs.prediction_logits + seq_relationship_logits = outputs.seq_relationship_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + total_size = attention_mask.size(-1) + visual_attention_mask.size(-1) + if labels.size(-1) != total_size: + raise ValueError( + "The labels provided should have same sequence length as total attention mask. " + f"Found labels with sequence length {labels.size(-1)}, expected {total_size}." + ) + + outputs = self.visual_bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + visual_embeds=visual_embeds, + visual_attention_mask=visual_attention_mask, + visual_token_type_ids=visual_token_type_ids, + image_text_alignment=image_text_alignment, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output, pooled_output = outputs[:2] + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and sentence_image_labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + sentence_image_loss = loss_fct(seq_relationship_score.view(-1, 2), sentence_image_labels.view(-1)) + total_loss = masked_lm_loss + sentence_image_loss + + elif labels is not None: + loss_fct = CrossEntropyLoss() + total_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores, seq_relationship_score) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return VisualBertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + VisualBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and + a softmax) e.g. for VCR tasks. 
+ """, + VISUAL_BERT_START_DOCSTRING, +) +class VisualBertForMultipleChoice(VisualBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.visual_bert = VisualBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.cls = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward( + VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @replace_return_docstrings(output_type=MultipleChoiceModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + visual_embeds: Optional[torch.FloatTensor] = None, + visual_attention_mask: Optional[torch.LongTensor] = None, + visual_token_type_ids: Optional[torch.LongTensor] = None, + image_text_alignment: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + + Returns: + + Example: + + ```python + # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch. + from transformers import AutoTokenizer, VisualBertForMultipleChoice + import torch + + tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + model = VisualBertForMultipleChoice.from_pretrained("uclanlp/visualbert-vcr") + + prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced." + choice0 = "It is eaten with a fork and a knife." + choice1 = "It is eaten while held in the hand." 
+ + visual_embeds = get_visual_embeddings(image) + # (batch_size, num_choices, visual_seq_length, visual_embedding_dim) + visual_embeds = visual_embeds.expand(1, 2, *visual_embeds.shape) + visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) + visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float) + + labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1 + + encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors="pt", padding=True) + # batch size is 1 + inputs_dict = {k: v.unsqueeze(0) for k, v in encoding.items()} + inputs_dict.update( + { + "visual_embeds": visual_embeds, + "visual_attention_mask": visual_attention_mask, + "visual_token_type_ids": visual_token_type_ids, + "labels": labels, + } + ) + outputs = model(**inputs_dict) + + loss = outputs.loss + logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + visual_embeds = ( + visual_embeds.view(-1, visual_embeds.size(-2), visual_embeds.size(-1)) + if visual_embeds is not None + else None + ) + visual_attention_mask = ( + visual_attention_mask.view(-1, visual_attention_mask.size(-1)) + if visual_attention_mask is not None + else None + ) + visual_token_type_ids = ( + visual_token_type_ids.view(-1, visual_token_type_ids.size(-1)) + if visual_token_type_ids is not None + else None + ) + + outputs = self.visual_bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + visual_embeds=visual_embeds, + visual_attention_mask=visual_attention_mask, + visual_token_type_ids=visual_token_type_ids, + image_text_alignment=image_text_alignment, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + _, pooled_output = outputs[0], outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.cls(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + VisualBert Model with a classification/regression head on top (a dropout and a linear layer on top of the pooled + output) for VQA. 
+ """, + VISUAL_BERT_START_DOCSTRING, +) +class VisualBertForQuestionAnswering(VisualBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.visual_bert = VisualBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.cls = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + visual_embeds: Optional[torch.FloatTensor] = None, + visual_attention_mask: Optional[torch.LongTensor] = None, + visual_token_type_ids: Optional[torch.LongTensor] = None, + image_text_alignment: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, total_sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. A KLDivLoss is computed between the labels and the returned logits. + + Returns: + + Example: + + ```python + # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch. + from transformers import AutoTokenizer, VisualBertForQuestionAnswering + import torch + + tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + model = VisualBertForQuestionAnswering.from_pretrained("uclanlp/visualbert-vqa") + + text = "Who is eating the apple?" 
+ inputs = tokenizer(text, return_tensors="pt") + visual_embeds = get_visual_embeddings(image).unsqueeze(0) + visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) + visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float) + + inputs.update( + { + "visual_embeds": visual_embeds, + "visual_token_type_ids": visual_token_type_ids, + "visual_attention_mask": visual_attention_mask, + } + ) + + labels = torch.tensor([[0.0, 1.0]]).unsqueeze(0) # Batch size 1, Num labels 2 + + outputs = model(**inputs, labels=labels) + loss = outputs.loss + scores = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Get the index of the last text token + index_to_gather = attention_mask.sum(1) - 2 # as in original code + + outputs = self.visual_bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + visual_embeds=visual_embeds, + visual_attention_mask=visual_attention_mask, + visual_token_type_ids=visual_token_type_ids, + image_text_alignment=image_text_alignment, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # TO-CHECK: From the original code + index_to_gather = ( + index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(index_to_gather.size(0), 1, sequence_output.size(-1)) + ) + pooled_output = torch.gather(sequence_output, 1, index_to_gather) + + pooled_output = self.dropout(pooled_output) + logits = self.cls(pooled_output) + reshaped_logits = logits.view(-1, self.num_labels) + + loss = None + if labels is not None: + loss_fct = nn.KLDivLoss(reduction="batchmean") + log_softmax = nn.LogSoftmax(dim=-1) + reshaped_logits = log_softmax(reshaped_logits) + loss = loss_fct(reshaped_logits, labels.contiguous()) + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + VisualBert Model with a sequence classification head on top (a dropout and a linear layer on top of the pooled + output) for Visual Reasoning e.g. for NLVR task. 
+ """, + VISUAL_BERT_START_DOCSTRING, +) +class VisualBertForVisualReasoning(VisualBertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.visual_bert = VisualBertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.cls = nn.Linear(config.hidden_size, config.num_labels) # 2 + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + visual_embeds: Optional[torch.FloatTensor] = None, + visual_attention_mask: Optional[torch.LongTensor] = None, + visual_token_type_ids: Optional[torch.LongTensor] = None, + image_text_alignment: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.LongTensor] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. A classification loss is computed (Cross-Entropy) against these labels. + + Returns: + + Example: + + ```python + # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch. + from transformers import AutoTokenizer, VisualBertForVisualReasoning + import torch + + tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") + model = VisualBertForVisualReasoning.from_pretrained("uclanlp/visualbert-nlvr2") + + text = "Who is eating the apple?" 
+ inputs = tokenizer(text, return_tensors="pt") + visual_embeds = get_visual_embeddings(image).unsqueeze(0) + visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) + visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float) + + inputs.update( + { + "visual_embeds": visual_embeds, + "visual_token_type_ids": visual_token_type_ids, + "visual_attention_mask": visual_attention_mask, + } + ) + + labels = torch.tensor(1).unsqueeze(0) # Batch size 1, Num choices 2 + + outputs = model(**inputs, labels=labels) + loss = outputs.loss + scores = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.visual_bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + visual_embeds=visual_embeds, + visual_attention_mask=visual_attention_mask, + visual_token_type_ids=visual_token_type_ids, + image_text_alignment=image_text_alignment, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # sequence_output = outputs[0] + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.cls(pooled_output) + reshaped_logits = logits.contiguous() + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class VisualBertRegionToPhraseAttention(nn.Module): + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = 1 # config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, query, key, attention_mask): + attention_mask = attention_mask.to(query.dtype) + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min + + mixed_query_layer = self.query(query) + mixed_key_layer = self.key(key) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + attention_scores = attention_scores + attention_mask + + attention_scores = attention_scores.squeeze(1) + return attention_scores + + +@add_start_docstrings( + """ + VisualBert Model with a Masked Language Modeling head 
and an attention layer on top for Region-to-Phrase Alignment
+    e.g. for the Flickr30k Entities task.
+    """,
+    VISUAL_BERT_START_DOCSTRING,
+)
+class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
+    _tied_weights_keys = ["cls.predictions.decoder.bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.visual_bert = VisualBertModel(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.cls = VisualBertPreTrainingHeads(config)
+        self.attention = VisualBertRegionToPhraseAttention(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        visual_embeds: Optional[torch.FloatTensor] = None,
+        visual_attention_mask: Optional[torch.LongTensor] = None,
+        visual_token_type_ids: Optional[torch.LongTensor] = None,
+        image_text_alignment: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        region_to_phrase_position: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        r"""
+        region_to_phrase_position (`torch.LongTensor` of shape `(batch_size, total_sequence_length)`, *optional*):
+            The positions, in the total (textual + visual) input sequence, of the tokens whose alignment with the
+            image regions should be predicted. Positions set to -1 are masked out (treated as padding).
+
+        labels (`torch.LongTensor` of shape `(batch_size, total_sequence_length, visual_sequence_length)`, *optional*):
+            Labels for computing the region-to-phrase alignment loss. A KL-divergence loss is computed between these
+            labels and the attention scores returned by the alignment head.
+
+        Returns:
+
+        Example:
+
+        ```python
+        # Assumption: *get_visual_embeddings(image)* gets the visual embeddings of the image in the batch.
+        from transformers import AutoTokenizer, VisualBertForRegionToPhraseAlignment
+        import torch
+
+        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        model = VisualBertForRegionToPhraseAlignment.from_pretrained("uclanlp/visualbert-vqa-coco-pre")
+
+        text = "Who is eating the apple?"
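+        # `get_visual_embeddings` is again a stand-in for a visual backbone (a sketch of a
+        # placeholder is given in the VisualBertForVisualReasoning example above). Note that
+        # `region_to_phrase_position` and `labels` below span the *total* sequence length,
+        # i.e. the number of text tokens plus the number of visual embeddings.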
+ inputs = tokenizer(text, return_tensors="pt") + visual_embeds = get_visual_embeddings(image).unsqueeze(0) + visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) + visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float) + region_to_phrase_position = torch.ones((1, inputs["input_ids"].shape[-1] + visual_embeds.shape[-2])) + + inputs.update( + { + "region_to_phrase_position": region_to_phrase_position, + "visual_embeds": visual_embeds, + "visual_token_type_ids": visual_token_type_ids, + "visual_attention_mask": visual_attention_mask, + } + ) + + labels = torch.ones( + (1, inputs["input_ids"].shape[-1] + visual_embeds.shape[-2], visual_embeds.shape[-2]) + ) # Batch size 1 + + outputs = model(**inputs, labels=labels) + loss = outputs.loss + scores = outputs.logits + ```""" + if region_to_phrase_position is None: + raise ValueError("`region_to_phrase_position` should not be None when using Flickr Model.") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.visual_bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + visual_embeds=visual_embeds, + visual_attention_mask=visual_attention_mask, + visual_token_type_ids=visual_token_type_ids, + image_text_alignment=image_text_alignment, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + region_to_phrase_position_mask = (region_to_phrase_position != -1).long() + + # Make the -1 become 0 + region_to_phrase_position = region_to_phrase_position * region_to_phrase_position_mask + + # Selected_positions = batch x selected position x dim + expanded_region_to_phrase_positions = region_to_phrase_position.unsqueeze(2).expand( + region_to_phrase_position.size(0), region_to_phrase_position.size(1), sequence_output.size(2) + ) + selected_positions = sequence_output.gather(1, expanded_region_to_phrase_positions) + + # Visual Features = batch x visual_feature_length x dim + # This will need separate image and visual masks. + visual_features = sequence_output[:, attention_mask.size(1) :] + + if visual_features.size(1) != visual_attention_mask.size(1): + raise ValueError( + f"Visual features length :{visual_features.size(1)} should be the same" + f" as visual attention mask length: {visual_attention_mask.size(1)}." 
+ ) + + logits = self.attention(selected_positions, visual_features, visual_attention_mask) + + loss = None + + if labels is not None: + # scores = batch x selected position x visual_feature + # scores = selected_positions.bmm(visual_features.transpose(1,2)) + # label = batch x selected_postion x needed position + loss_fct = KLDivLoss(reduction="batchmean") + log_softmax = LogSoftmax(dim=-1) + scores = log_softmax(logits) + labels = labels.contiguous() + loss = loss_fct(scores, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "VisualBertForMultipleChoice", + "VisualBertForPreTraining", + "VisualBertForQuestionAnswering", + "VisualBertForRegionToPhraseAlignment", + "VisualBertForVisualReasoning", + "VisualBertLayer", + "VisualBertModel", + "VisualBertPreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/vit/__init__.py b/docs/transformers/build/lib/transformers/models/vit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6a7a23fa63f4c95102584b90d7f775b746ce49 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_vit import * + from .feature_extraction_vit import * + from .image_processing_vit import * + from .image_processing_vit_fast import * + from .modeling_flax_vit import * + from .modeling_tf_vit import * + from .modeling_vit import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/vit/configuration_vit.py b/docs/transformers/build/lib/transformers/models/vit/configuration_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..13ad3a7715c5dd411147e33823e0d69e7106005d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit/configuration_vit.py @@ -0,0 +1,151 @@ +# coding=utf-8 +# Copyright 2021 Google AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ViT model configuration"""
+
+from collections import OrderedDict
+from typing import Mapping
+
+from packaging import version
+
+from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class ViTConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`ViTModel`]. It is used to instantiate a ViT
+    model according to the specified arguments, defining the model architecture. Instantiating a configuration with
+    the defaults will yield a similar configuration to that of the ViT
+    [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
+        patch_size (`int`, *optional*, defaults to 16):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+        encoder_stride (`int`, *optional*, defaults to 16):
+            Factor to increase the spatial resolution by in the decoder head for masked image modeling.
+        pooler_output_size (`int`, *optional*):
+            Dimensionality of the pooler layer. If None, defaults to `hidden_size`.
+        pooler_act (`str`, *optional*, defaults to `"tanh"`):
+            The activation function to be used by the pooler. Keys of ACT2FN are supported for Flax and
+            PyTorch, and elements of https://www.tensorflow.org/api_docs/python/tf/keras/activations are
+            supported for TensorFlow.
+ + Example: + + ```python + >>> from transformers import ViTConfig, ViTModel + + >>> # Initializing a ViT vit-base-patch16-224 style configuration + >>> configuration = ViTConfig() + + >>> # Initializing a model (with random weights) from the vit-base-patch16-224 style configuration + >>> model = ViTModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vit" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + encoder_stride=16, + pooler_output_size=None, + pooler_act="tanh", + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.encoder_stride = encoder_stride + self.pooler_output_size = pooler_output_size if pooler_output_size else hidden_size + self.pooler_act = pooler_act + + +class ViTOnnxConfig(OnnxConfig): + torch_onnx_minimum_version = version.parse("1.11") + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + return OrderedDict( + [ + ("pixel_values", {0: "batch", 1: "num_channels", 2: "height", 3: "width"}), + ] + ) + + @property + def atol_for_validation(self) -> float: + return 1e-4 + + +__all__ = ["ViTConfig", "ViTOnnxConfig"] diff --git a/docs/transformers/build/lib/transformers/models/vit/convert_dino_to_pytorch.py b/docs/transformers/build/lib/transformers/models/vit/convert_dino_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..8608da8eb411644ab214666560cccae5a9285213 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit/convert_dino_to_pytorch.py @@ -0,0 +1,218 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
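+# Example invocation (a sketch; the dump folder path is a placeholder):
+#
+#   python convert_dino_to_pytorch.py --model_name dino_vitb16 \
+#       --pytorch_dump_folder_path ./dino_vitb16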
+"""Convert ViT checkpoints trained with the DINO method.""" + +import argparse +import json +from pathlib import Path + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, base_model=False): + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias")) + + # projection layer + position embeddings + rename_keys.extend( + [ + ("cls_token", "vit.embeddings.cls_token"), + ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"), + ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"), + ("pos_embed", "vit.embeddings.position_embeddings"), + ] + ) + + if base_model: + # layernorm + pooler + rename_keys.extend( + [ + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ] + ) + + # if just the base model, we should remove "vit" from all keys that start with "vit" + rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys] + else: + # layernorm + classification head + rename_keys.extend( + [ + ("norm.weight", "vit.layernorm.weight"), + ("norm.bias", "vit.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, base_model=False): + for i in range(config.num_hidden_layers): + if base_model: + prefix = "" + else: + prefix = "vit." 
+ # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +def remove_classification_head_(state_dict): + ignore_keys = ["head.weight", "head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_vit_checkpoint(model_name, pytorch_dump_folder_path, base_model=True): + """ + Copy/paste/tweak model's weights to our ViT structure. + """ + + # define default ViT configuration + config = ViTConfig() + # patch_size + if model_name[-1] == "8": + config.patch_size = 8 + # set labels if required + if not base_model: + config.num_labels = 1000 + repo_id = "huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + # size of the architecture + if model_name in ["dino_vits8", "dino_vits16"]: + config.hidden_size = 384 + config.intermediate_size = 1536 + config.num_hidden_layers = 12 + config.num_attention_heads = 6 + + # load original model from torch hub + original_model = torch.hub.load("facebookresearch/dino:main", model_name) + original_model.eval() + + # load state_dict of original model, remove and rename some keys + state_dict = original_model.state_dict() + if base_model: + remove_classification_head_(state_dict) + rename_keys = create_rename_keys(config, base_model=base_model) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, base_model) + + # load HuggingFace model + if base_model: + model = ViTModel(config, add_pooling_layer=False).eval() + else: + model = ViTForImageClassification(config).eval() + model.load_state_dict(state_dict) + + # Check outputs on an image, prepared by ViTImageProcessor + image_processor = ViTImageProcessor() + encoding = image_processor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + outputs = model(pixel_values) + + if base_model: + final_hidden_state_cls_token = original_model(pixel_values) + assert torch.allclose(final_hidden_state_cls_token, outputs.last_hidden_state[:, 0, :], atol=1e-1) + else: + logits = original_model(pixel_values) + assert 
logits.shape == outputs.logits.shape + assert torch.allclose(logits, outputs.logits, atol=1e-3) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="dino_vitb16", + type=str, + help="Name of the model trained with DINO you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--base_model", + action="store_true", + help="Whether to only convert the base model (no projection head weights).", + ) + + parser.set_defaults(base_model=True) + args = parser.parse_args() + convert_vit_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.base_model) diff --git a/docs/transformers/build/lib/transformers/models/vit/convert_vit_timm_to_pytorch.py b/docs/transformers/build/lib/transformers/models/vit/convert_vit_timm_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..7892842f8dc18d4eea990acc6a4bac31ff916919 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit/convert_vit_timm_to_pytorch.py @@ -0,0 +1,254 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
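+# Example invocation (a sketch; the dump folder path is a placeholder):
+#
+#   python convert_vit_timm_to_pytorch.py --vit_name vit_base_patch16_224 \
+#       --pytorch_dump_folder_path ./vit_base_patch16_224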
+"""Convert ViT and non-distilled DeiT checkpoints from the timm library.""" + +import argparse +from pathlib import Path + +import requests +import timm +import torch +from PIL import Image +from timm.data import ImageNetInfo, infer_imagenet_subset + +from transformers import DeiTImageProcessor, ViTConfig, ViTForImageClassification, ViTImageProcessor, ViTModel +from transformers.utils import logging + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, base_model=False): + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append((f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight")) + rename_keys.append((f"blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias")) + rename_keys.append((f"blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias")) + + # projection layer + position embeddings + rename_keys.extend( + [ + ("cls_token", "vit.embeddings.cls_token"), + ("patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"), + ("patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"), + ("pos_embed", "vit.embeddings.position_embeddings"), + ] + ) + + if base_model: + # layernorm + rename_keys.extend( + [ + ("norm.weight", "layernorm.weight"), + ("norm.bias", "layernorm.bias"), + ] + ) + + # if just the base model, we should remove "vit" from all keys that start with "vit" + rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys] + else: + # layernorm + classification head + rename_keys.extend( + [ + ("norm.weight", "vit.layernorm.weight"), + ("norm.bias", "vit.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, base_model=False): + for i in range(config.num_hidden_layers): + if base_model: + prefix = "" + else: + prefix = "vit." 
+ # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +def remove_classification_head_(state_dict): + ignore_keys = ["head.weight", "head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +# We will verify our results on an image of cute cats +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + im = Image.open(requests.get(url, stream=True).raw) + return im + + +@torch.no_grad() +def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path): + """ + Copy/paste/tweak model's weights to our ViT structure. + """ + + # define default ViT configuration + config = ViTConfig() + base_model = False + + # load original model from timm + timm_model = timm.create_model(vit_name, pretrained=True) + timm_model.eval() + + # detect unsupported ViT models in transformers + # fc_norm is present + if not isinstance(getattr(timm_model, "fc_norm", None), torch.nn.Identity): + raise ValueError(f"{vit_name} is not supported in transformers because of the presence of fc_norm.") + + # use of global average pooling in combination (or without) class token + if getattr(timm_model, "global_pool", None) == "avg": + raise ValueError(f"{vit_name} is not supported in transformers because of use of global average pooling.") + + # CLIP style vit with norm_pre layer present + if "clip" in vit_name and not isinstance(getattr(timm_model, "norm_pre", None), torch.nn.Identity): + raise ValueError( + f"{vit_name} is not supported in transformers because it's a CLIP style ViT with norm_pre layer." + ) + + # SigLIP style vit with attn_pool layer present + if "siglip" in vit_name and getattr(timm_model, "global_pool", None) == "map": + raise ValueError( + f"{vit_name} is not supported in transformers because it's a SigLIP style ViT with attn_pool." 
+ ) + + # use of layer scale in ViT model blocks + if not isinstance(getattr(timm_model.blocks[0], "ls1", None), torch.nn.Identity) or not isinstance( + getattr(timm_model.blocks[0], "ls2", None), torch.nn.Identity + ): + raise ValueError(f"{vit_name} is not supported in transformers because it uses a layer scale in its blocks.") + + # Hybrid ResNet-ViTs + if not isinstance(timm_model.patch_embed, timm.layers.PatchEmbed): + raise ValueError(f"{vit_name} is not supported in transformers because it is a hybrid ResNet-ViT.") + + # get patch size and image size from the patch embedding submodule + config.patch_size = timm_model.patch_embed.patch_size[0] + config.image_size = timm_model.patch_embed.img_size[0] + + # retrieve architecture-specific parameters from the timm model + config.hidden_size = timm_model.embed_dim + config.intermediate_size = timm_model.blocks[0].mlp.fc1.out_features + config.num_hidden_layers = len(timm_model.blocks) + config.num_attention_heads = timm_model.blocks[0].attn.num_heads + + # check whether the model has a classification head or not + if timm_model.num_classes != 0: + config.num_labels = timm_model.num_classes + # infer ImageNet subset from timm model + imagenet_subset = infer_imagenet_subset(timm_model) + dataset_info = ImageNetInfo(imagenet_subset) + config.id2label = {i: dataset_info.index_to_label_name(i) for i in range(dataset_info.num_classes())} + config.label2id = {v: k for k, v in config.id2label.items()} + else: + print(f"{vit_name} is going to be converted as a feature extractor only.") + base_model = True + + # load state_dict of original model + state_dict = timm_model.state_dict() + + # remove and rename some keys in the state dict + if base_model: + remove_classification_head_(state_dict) + rename_keys = create_rename_keys(config, base_model) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, base_model) + + # load HuggingFace model + if base_model: + model = ViTModel(config, add_pooling_layer=False).eval() + else: + model = ViTForImageClassification(config).eval() + model.load_state_dict(state_dict) + + # Check outputs on an image, prepared by ViTImageProcessor/DeiTImageProcessor + if "deit" in vit_name: + image_processor = DeiTImageProcessor(size=config.image_size) + else: + image_processor = ViTImageProcessor(size=config.image_size) + encoding = image_processor(images=prepare_img(), return_tensors="pt") + pixel_values = encoding["pixel_values"] + outputs = model(pixel_values) + + if base_model: + timm_pooled_output = timm_model.forward_features(pixel_values) + assert timm_pooled_output.shape == outputs.last_hidden_state.shape + assert torch.allclose(timm_pooled_output, outputs.last_hidden_state, atol=1e-1) + else: + timm_logits = timm_model(pixel_values) + assert timm_logits.shape == outputs.logits.shape + assert torch.allclose(timm_logits, outputs.logits, atol=1e-3) + + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + print(f"Saving model {vit_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--vit_name", + default="vit_base_patch16_224", + type=str, + help="Name of the ViT timm model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to 
the output PyTorch model directory." + ) + + args = parser.parse_args() + convert_vit_checkpoint(args.vit_name, args.pytorch_dump_folder_path) diff --git a/docs/transformers/build/lib/transformers/models/vit/feature_extraction_vit.py b/docs/transformers/build/lib/transformers/models/vit/feature_extraction_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a93df0bd8f82c81a715f78713a27b17672ecf6 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit/feature_extraction_vit.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Feature extractor class for ViT.""" + +import warnings + +from ...utils import logging +from ...utils.import_utils import requires +from .image_processing_vit import ViTImageProcessor + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class ViTFeatureExtractor(ViTImageProcessor): + def __init__(self, *args, **kwargs) -> None: + warnings.warn( + "The class ViTFeatureExtractor is deprecated and will be removed in version 5 of Transformers. Please" + " use ViTImageProcessor instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + + +__all__ = ["ViTFeatureExtractor"] diff --git a/docs/transformers/build/lib/transformers/models/vit/image_processing_vit.py b/docs/transformers/build/lib/transformers/models/vit/image_processing_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ade7495b1d4a690a0f41319aed6bde3a5078de77 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit/image_processing_vit.py @@ -0,0 +1,288 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
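+# Example usage (a sketch; `image` stands for any PIL image or numpy array):
+#
+#   from transformers import ViTImageProcessor
+#
+#   processor = ViTImageProcessor()
+#   batch = processor(images=image, return_tensors="pt")
+#   # batch["pixel_values"] has shape (1, 3, 224, 224) with the default size and scaling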
+"""Image processor class for ViT.""" + +from typing import Dict, List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +@requires(backends=("vision",)) +class ViTImageProcessor(BaseImageProcessor): + r""" + Constructs a ViT image processor. + + Args: + do_resize (`bool`, *optional*, defaults to `True`): + Whether to resize the image's (height, width) dimensions to the specified `(size["height"], + size["width"])`. Can be overridden by the `do_resize` parameter in the `preprocess` method. + size (`dict`, *optional*, defaults to `{"height": 224, "width": 224}`): + Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess` + method. + resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): + Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the + `preprocess` method. + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_convert_rgb (`bool`, *optional*): + Whether to convert the image to RGB. 
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_resize: bool = True, + size: Optional[Dict[str, int]] = None, + resample: PILImageResampling = PILImageResampling.BILINEAR, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_convert_rgb: Optional[bool] = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + size = size if size is not None else {"height": 224, "width": 224} + size = get_size_dict(size) + self.do_resize = do_resize + self.do_rescale = do_rescale + self.do_normalize = do_normalize + self.size = size + self.resample = resample + self.rescale_factor = rescale_factor + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.do_convert_rgb = do_convert_rgb + + def resize( + self, + image: np.ndarray, + size: Dict[str, int], + resample: PILImageResampling = PILImageResampling.BILINEAR, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ) -> np.ndarray: + """ + Resize an image to `(size["height"], size["width"])`. + + Args: + image (`np.ndarray`): + Image to resize. + size (`Dict[str, int]`): + Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image. + resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`): + `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BILINEAR`. + data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the output image. If unset, the channel dimension format of the input + image is used. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + + Returns: + `np.ndarray`: The resized image. + """ + size = get_size_dict(size) + if "height" not in size or "width" not in size: + raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. 
Got {size.keys()}") + output_size = (size["height"], size["width"]) + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + do_convert_rgb: Optional[bool] = None, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the image. + size (`Dict[str, int]`, *optional*, defaults to `self.size`): + Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after + resizing. + resample (`PILImageResampling` filter, *optional*, defaults to `self.resample`): + `PILImageResampling` filter to use if resizing the image e.g. `PILImageResampling.BILINEAR`. Only has + an effect if `do_resize` is set to `True`. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`): + Whether to convert the image to RGB. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + resample = resample if resample is not None else self.resample + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb + + size = size if size is not None else self.size + size_dict = get_size_dict(size) + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if do_convert_rgb: + images = [convert_to_rgb(image) for image in images] + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_resize: + images = [ + self.resize(image=image, size=size_dict, resample=resample, input_data_format=input_data_format) + for image in images + ] + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["ViTImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/vit/image_processing_vit_fast.py b/docs/transformers/build/lib/transformers/models/vit/image_processing_vit_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..61277792cdaa9dda4550221baeb4addae4b8c49e --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit/image_processing_vit_fast.py @@ -0,0 +1,45 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for ViT."""
+
+from ...image_processing_utils_fast import (
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+    BaseImageProcessorFast,
+)
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    PILImageResampling,
+)
+from ...utils import (
+    add_start_docstrings,
+)
+
+
+@add_start_docstrings(
+    "Constructs a fast ViT image processor.",
+    BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
+)
+class ViTImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BILINEAR
+    image_mean = IMAGENET_STANDARD_MEAN
+    image_std = IMAGENET_STANDARD_STD
+    size = {"height": 224, "width": 224}
+    do_resize = True
+    do_rescale = True
+    do_normalize = True
+
+
+__all__ = ["ViTImageProcessorFast"]
diff --git a/docs/transformers/build/lib/transformers/models/vit/modeling_flax_vit.py b/docs/transformers/build/lib/transformers/models/vit/modeling_flax_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf3477b5ddff66caac2b84ed6ba2f0e00336780
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/vit/modeling_flax_vit.py
@@ -0,0 +1,677 @@
+# coding=utf-8
+# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
+from flax.linen.attention import dot_product_attention_weights
+from flax.traverse_util import flatten_dict, unflatten_dict
+
+from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxBaseModelOutputWithPooling, FlaxSequenceClassifierOutput
+from ...modeling_flax_utils import (
+    ACT2FN,
+    FlaxPreTrainedModel,
+    append_replace_return_docstrings,
+    overwrite_call_docstring,
+)
+from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
+from .configuration_vit import ViTConfig
+
+
+VIT_START_DOCSTRING = r"""
+
+    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
+
+    This model is also a
+    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
+    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
+    behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
+    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
+    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
+
+    Parameters:
+        config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
+        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+            `jax.numpy.bfloat16` (on TPUs).
+
+            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+            specified, all the computation will be performed with the given `dtype`.
+
+            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+            parameters.**
+
+            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+            [`~FlaxPreTrainedModel.to_bf16`].
+"""
+
+VIT_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
+            for details.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+class FlaxViTPatchEmbeddings(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        image_size = self.config.image_size
+        patch_size = self.config.patch_size
+        num_patches = (image_size // patch_size) * (image_size // patch_size)
+        self.num_patches = num_patches
+        self.num_channels = self.config.num_channels
+        self.projection = nn.Conv(
+            self.config.hidden_size,
+            kernel_size=(patch_size, patch_size),
+            strides=(patch_size, patch_size),
+            padding="VALID",
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, "fan_in", "truncated_normal"
+            ),
+        )
+
+    def __call__(self, pixel_values):
+        num_channels = pixel_values.shape[-1]
+        if num_channels != self.num_channels:
+            raise ValueError(
+                "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+            )
+        embeddings = self.projection(pixel_values)
+        batch_size, _, _, channels = embeddings.shape
+        return jnp.reshape(embeddings, (batch_size, -1, channels))
+
+
+class FlaxViTEmbeddings(nn.Module):
+    """Construct the CLS token, position and patch embeddings."""
+
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        self.cls_token = self.param(
+            "cls_token",
+            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+            (1, 1, self.config.hidden_size),
+        )
+        self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
+        num_patches = self.patch_embeddings.num_patches
+        self.position_embeddings = self.param(
+            "position_embeddings",
+            jax.nn.initializers.variance_scaling(self.config.initializer_range**2, "fan_in", "truncated_normal"),
+            (1, num_patches + 1, self.config.hidden_size),
+        )
+        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
+
+    def __call__(self, pixel_values, deterministic=True):
+        batch_size = pixel_values.shape[0]
+
+        embeddings = self.patch_embeddings(pixel_values)
+
+        cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
+        embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)
+        embeddings = embeddings + self.position_embeddings
+        embeddings = self.dropout(embeddings, deterministic=deterministic)
+        return embeddings
+
+
+class FlaxViTSelfAttention(nn.Module):
+    config: ViTConfig
+    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+
+    def setup(self):
+        if self.config.hidden_size % self.config.num_attention_heads != 0:
+            # note: these must be f-strings, otherwise the config values are not interpolated
+            raise ValueError(
+                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`:"
+                f" {self.config.num_attention_heads}"
+            )
+
+        self.query = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
+            use_bias=self.config.qkv_bias,
+        )
+        self.key = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
+            use_bias=self.config.qkv_bias,
+        )
+        self.value = nn.Dense(
+            self.config.hidden_size,
+            dtype=self.dtype,
+            kernel_init=jax.nn.initializers.variance_scaling(
+                self.config.initializer_range**2, mode="fan_in", distribution="truncated_normal"
+            ),
+            use_bias=self.config.qkv_bias,
+        )
+
+    def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False):
+        head_dim = self.config.hidden_size // self.config.num_attention_heads
+
+        query_states = self.query(hidden_states).reshape(
+            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+        )
+        value_states = self.value(hidden_states).reshape(
+            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+        )
+        key_states = self.key(hidden_states).reshape(
+            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
+        )
+
+        dropout_rng = None
+        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
+            dropout_rng = self.make_rng("dropout")
+
+        attn_weights = dot_product_attention_weights(
+            query_states,
+            key_states,
+            dropout_rng=dropout_rng,
+            dropout_rate=self.config.attention_probs_dropout_prob,
+            broadcast_dropout=True,
+            deterministic=deterministic,
+            dtype=self.dtype,
+            precision=None,
+        )
+
+        attn_output = 
jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +class FlaxViTSelfOutput(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxViTAttention(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.attention = FlaxViTSelfAttention(self.config, dtype=self.dtype) + self.output = FlaxViTSelfOutput(self.config, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic=True, output_attentions: bool = False): + attn_outputs = self.attention(hidden_states, deterministic=deterministic, output_attentions=output_attentions) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +class FlaxViTIntermediate(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class FlaxViTOutput(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = hidden_states + attention_output + return hidden_states + + +class FlaxViTLayer(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxViTAttention(self.config, dtype=self.dtype) + self.intermediate = FlaxViTIntermediate(self.config, dtype=self.dtype) + self.output = FlaxViTOutput(self.config, dtype=self.dtype) + self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, deterministic: bool = True, output_attentions: bool = False): + attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention + 
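+            # together with the second layernorm below, this makes each ViT block pre-LN:
+            # LN -> self-attention -> residual add, then LN -> MLP -> residual add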
deterministic=deterministic, + output_attentions=output_attentions, + ) + + attention_output = attention_outputs[0] + + # first residual connection + attention_output = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(attention_output) + + hidden_states = self.intermediate(layer_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + return outputs + + +class FlaxViTLayerCollection(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxViTLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer(hidden_states, deterministic=deterministic, output_attentions=output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states,) + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxViTEncoder(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layer = FlaxViTLayerCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class FlaxViTPooler(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.pooler_output_size, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.pooler_act] + + def __call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return self.activation(cls_hidden_state) + + +class FlaxViTPreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = ViTConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + module_class: nn.Module = None + + def __init__( + self, + config: ViTConfig, + input_shape=None, + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + if input_shape is None: + input_shape = (1, config.image_size, config.image_size, config.num_channels) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + pixel_values = jnp.zeros(input_shape, dtype=self.dtype) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + pixel_values, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + return self.module.apply( + {"params": params or self.params}, + jnp.array(pixel_values, dtype=jnp.float32), + not train, + output_attentions, + output_hidden_states, + return_dict, + rngs=rngs, + ) + + +class FlaxViTModule(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + + def setup(self): + self.embeddings = FlaxViTEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxViTEncoder(self.config, dtype=self.dtype) + self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.pooler = FlaxViTPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None + + def __call__( + self, + pixel_values, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embeddings(pixel_values, deterministic=deterministic) + + outputs = self.encoder( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + hidden_states = self.layernorm(hidden_states) + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return 
(hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.", + VIT_START_DOCSTRING, +) +class FlaxViTModel(FlaxViTPreTrainedModel): + module_class = FlaxViTModule + + +FLAX_VISION_MODEL_DOCSTRING = """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, FlaxViTModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") + >>> model = FlaxViTModel.from_pretrained("google/vit-base-patch16-224-in21k") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ``` +""" + +overwrite_call_docstring(FlaxViTModel, FLAX_VISION_MODEL_DOCSTRING) +append_replace_return_docstrings(FlaxViTModel, output_type=FlaxBaseModelOutputWithPooling, config_class=ViTConfig) + + +class FlaxViTForImageClassificationModule(nn.Module): + config: ViTConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.vit = FlaxViTModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.classifier = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + kernel_init=jax.nn.initializers.variance_scaling( + self.config.initializer_range**2, "fan_in", "truncated_normal" + ), + ) + + def __call__( + self, + pixel_values=None, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vit( + pixel_values, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.classifier(hidden_states[:, 0, :]) + + if not return_dict: + output = (logits,) + outputs[2:] + return output + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. 
+ """, + VIT_START_DOCSTRING, +) +class FlaxViTForImageClassification(FlaxViTPreTrainedModel): + module_class = FlaxViTForImageClassificationModule + + +FLAX_VISION_CLASSIF_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> from transformers import AutoImageProcessor, FlaxViTForImageClassification + >>> from PIL import Image + >>> import jax + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224") + >>> model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224") + + >>> inputs = image_processor(images=image, return_tensors="np") + >>> outputs = model(**inputs) + >>> logits = outputs.logits + + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) + >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) + ``` +""" + +overwrite_call_docstring(FlaxViTForImageClassification, FLAX_VISION_CLASSIF_DOCSTRING) +append_replace_return_docstrings( + FlaxViTForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=ViTConfig +) + + +__all__ = ["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/vit/modeling_tf_vit.py b/docs/transformers/build/lib/transformers/models/vit/modeling_tf_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..e18b38e597f37298a75570a90f7c1019bf7c7788 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit/modeling_tf_vit.py @@ -0,0 +1,907 @@ +# coding=utf-8 +# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""TF 2.0 ViT model.""" + +from __future__ import annotations + +import collections.abc +import math +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput +from ...modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + TFSequenceClassificationLoss, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_vit import ViTConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "ViTConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k" +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" + + +class TFViTEmbeddings(keras.layers.Layer): + """ + Construct the CLS token, position and patch embeddings. + + """ + + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.patch_embeddings = TFViTPatchEmbeddings(config, name="patch_embeddings") + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def build(self, input_shape=None): + num_patches = self.patch_embeddings.num_patches + self.cls_token = self.add_weight( + shape=(1, 1, self.config.hidden_size), + initializer=get_initializer(self.config.initializer_range), + trainable=True, + name="cls_token", + ) + self.position_embeddings = self.add_weight( + shape=(1, num_patches + 1, self.config.hidden_size), + initializer=get_initializer(self.config.initializer_range), + trainable=True, + name="position_embeddings", + ) + + if self.built: + return + self.built = True + if getattr(self, "patch_embeddings", None) is not None: + with tf.name_scope(self.patch_embeddings.name): + self.patch_embeddings.build(None) + + def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + batch_size, seq_len, dim = shape_list(embeddings) + num_patches = seq_len - 1 + + _, num_positions, _ = shape_list(self.position_embeddings) + num_positions -= 1 + + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + patch_pos_embed = tf.image.resize( + images=tf.reshape( + patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + ), + size=(h0, w0), + method="bicubic", + ) + + shape = shape_list(patch_pos_embed) + assert h0 == shape[-3] and w0 == shape[-2] + patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) + return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1) + + def call( + self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False + ) -> tf.Tensor: + batch_size, num_channels, height, width = shape_list(pixel_values) + embeddings = self.patch_embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, training=training + ) + + # add the [CLS] token to the embedded patch tokens + cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0) + embeddings = tf.concat((cls_tokens, embeddings), axis=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings, training=training) + + return embeddings + + +# Based on timm implementation, which can be found here: +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +class TFViTPatchEmbeddings(keras.layers.Layer): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
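+
+    For example, with the default `image_size=224` and `patch_size=16`, the strided convolution below produces
+    (224 / 16) * (224 / 16) = 196 patch embeddings per image.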
+ """ + + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + self.num_channels = num_channels + self.config = config + + self.projection = keras.layers.Conv2D( + filters=hidden_size, + kernel_size=patch_size, + strides=patch_size, + padding="valid", + data_format="channels_last", + use_bias=True, + kernel_initializer=get_initializer(self.config.initializer_range), + bias_initializer="zeros", + name="projection", + ) + + def call( + self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False + ) -> tf.Tensor: + batch_size, num_channels, height, width = shape_list(pixel_values) + if tf.executing_eagerly() and num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + if not interpolate_pos_encoding: + if tf.executing_eagerly(): + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + + # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format. + # So change the input format from `NCHW` to `NHWC`. + # shape = (batch_size, in_height, in_width, in_channels=num_channels) + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + + projection = self.projection(pixel_values) + + # Change the 2D spatial dimensions to a single temporal dimension. 
+ # shape = (batch_size, num_patches, out_channels=embed_dim) + num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0]) + embeddings = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1)) + + return embeddings + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, None, self.num_channels]) + + +class TFViTSelfAttention(keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + mixed_key_layer = self.key(inputs=hidden_states) + mixed_value_layer = self.value(inputs=hidden_states) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) + value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
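+        # at this point attention_probs holds softmax(QK^T / sqrt(attention_head_size)) over the key axis, so
+        # dropping an entry removes one key/value token from the weighted sum computed further below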
+ attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +class TFViTSelfOutput(keras.layers.Layer): + """ + The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFViTAttention(keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFViTSelfAttention(config, name="attention") + self.dense_output = TFViTSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +class TFViTIntermediate(keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, 
kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +class TFViTOutput(keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = hidden_states + input_tensor + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + + +class TFViTLayer(keras.layers.Layer): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFViTAttention(config, name="attention") + self.intermediate = TFViTIntermediate(config, name="intermediate") + self.vit_output = TFViTOutput(config, name="output") + + self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") + self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + attention_outputs = self.attention( + # in ViT, layernorm is applied before self-attention + input_tensor=self.layernorm_before(inputs=hidden_states), + head_mask=head_mask, + output_attentions=output_attentions, + training=training, + ) + attention_output = attention_outputs[0] + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(inputs=hidden_states) + + intermediate_output = self.intermediate(hidden_states=layer_output) + + # second residual connection is done here + layer_output = self.vit_output( + hidden_states=intermediate_output, input_tensor=hidden_states, training=training + ) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if 
getattr(self, "vit_output", None) is not None: + with tf.name_scope(self.vit_output.name): + self.vit_output.build(None) + if getattr(self, "layernorm_before", None) is not None: + with tf.name_scope(self.layernorm_before.name): + self.layernorm_before.build([None, None, self.config.hidden_size]) + if getattr(self, "layernorm_after", None) is not None: + with tf.name_scope(self.layernorm_after.name): + self.layernorm_after.build([None, None, self.config.hidden_size]) + + +class TFViTEncoder(keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.layer = [TFViTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states=hidden_states, + head_mask=head_mask[i], + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFViTMainLayer(keras.layers.Layer): + config_class = ViTConfig + + def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, **kwargs): + super().__init__(**kwargs) + + self.config = config + + self.embeddings = TFViTEmbeddings(config, name="embeddings") + self.encoder = TFViTEncoder(config, name="encoder") + self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + self.pooler = TFViTPooler(config, name="pooler") if add_pooling_layer else None + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+            class PreTrainedModel
+        """
+        raise NotImplementedError
+
+    @unpack_inputs
+    def call(
+        self,
+        pixel_values: TFModelInputType | None = None,
+        head_mask: np.ndarray | tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        embedding_output = self.embeddings(
+            pixel_values=pixel_values,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
+        )
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        if head_mask is not None:
+            raise NotImplementedError
+        else:
+            head_mask = [None] * self.config.num_hidden_layers
+
+        encoder_outputs = self.encoder(
+            hidden_states=embedding_output,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(inputs=sequence_output)
+        pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return TFBaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "embeddings", None) is not None:
+            with tf.name_scope(self.embeddings.name):
+                self.embeddings.build(None)
+        if getattr(self, "encoder", None) is not None:
+            with tf.name_scope(self.encoder.name):
+                self.encoder.build(None)
+        if getattr(self, "layernorm", None) is not None:
+            with tf.name_scope(self.layernorm.name):
+                self.layernorm.build([None, None, self.config.hidden_size])
+        if getattr(self, "pooler", None) is not None:
+            with tf.name_scope(self.pooler.name):
+                self.pooler.build(None)
+
+
+class TFViTPreTrainedModel(TFPreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ViTConfig
+    base_model_prefix = "vit"
+    main_input_name = "pixel_values"
+
+
+VIT_START_DOCSTRING = r"""
+
+    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
+    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
+    behavior.
+
+    TensorFlow models and layers in `transformers` accept two formats as input:
+
+    - having all inputs as keyword arguments (like PyTorch models), or
+    - having all inputs as a list, tuple or dict in the first positional argument.
+
+    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
+    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+      `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+      `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+    Args:
+        config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VIT_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, with each example having the shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
+            for details.
+
+        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+            config will be used instead.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+            used instead.
+        interpolate_pos_encoding (`bool`, *optional*):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode, in graph mode the value will always be set to True.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+""" + + +@add_start_docstrings( + "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.", + VIT_START_DOCSTRING, +) +class TFViTModel(TFViTPreTrainedModel): + def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit") + + @unpack_inputs + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def call( + self, + pixel_values: TFModelInputType | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]: + outputs = self.vit( + pixel_values=pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "vit", None) is not None: + with tf.name_scope(self.vit.name): + self.vit.build(None) + + +class TFViTPooler(keras.layers.Layer): + def __init__(self, config: ViTConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.pooler_output_size, + kernel_initializer=get_initializer(config.initializer_range), + activation=config.pooler_act, + name="dense", + ) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +@add_start_docstrings( + """ + ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + + + + Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained + position embeddings to the higher resolution. 
+ + + """, + VIT_START_DOCSTRING, +) +class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss): + def __init__(self, config: ViTConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.num_labels = config.num_labels + self.vit = TFViTMainLayer(config, add_pooling_layer=False, name="vit") + + # Classifier head + self.classifier = keras.layers.Dense( + units=config.num_labels, + kernel_initializer=get_initializer(config.initializer_range), + name="classifier", + ) + self.config = config + + @unpack_inputs + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def call( + self, + pixel_values: TFModelInputType | None = None, + head_mask: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: np.ndarray | tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + outputs = self.vit( + pixel_values=pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.classifier(inputs=sequence_output[:, 0, :]) + loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "vit", None) is not None: + with tf.name_scope(self.vit.name): + self.vit.build(None) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.hidden_size]) + + +__all__ = ["TFViTForImageClassification", "TFViTModel", "TFViTPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/vit/modeling_vit.py b/docs/transformers/build/lib/transformers/models/vit/modeling_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..d757aeaf28b1ac0a32f042d98ef4bf929fe7f910 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit/modeling_vit.py @@ -0,0 +1,883 @@ +# coding=utf-8 +# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ViT model.""" + +import collections.abc +import math +from typing import Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + ImageClassifierOutput, + MaskedImageModelingOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_vit import ViTConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "ViTConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k" +_EXPECTED_OUTPUT_SHAPE = [1, 197, 768] + +# Image classification docstring +_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224" +_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat" + + +class ViTEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + """ + + def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = ViTPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class ViTPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
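+
+    The projection is a `Conv2d` whose kernel size and stride both equal `patch_size`, which is equivalent to
+    slicing the image into non-overlapping patches and applying one shared linear projection to each flattened
+    patch.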
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.num_patches = num_patches
+
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
+ batch_size, num_channels, height, width = pixel_values.shape
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
+ f" Expected {self.num_channels} but got {num_channels}."
+ )
+ if not interpolate_pos_encoding:
+ if height != self.image_size[0] or width != self.image_size[1]:
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+ embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+ return embeddings
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs,
+):
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+
+ # Normalize the attention scores to probabilities.
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+
+ # Mask heads if we want to
+ if attention_mask is not None:
+ attn_weights = attn_weights * attention_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class ViTSelfAttention(nn.Module):
+ def __init__(self, config: ViTConfig) -> None:
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+ f"heads {config.num_attention_heads}."
+ ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class ViTSelfOutput(nn.Module): + """ + The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
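+ The `input_tensor` argument is accepted only to keep the same interface as other models; it is not used
+ here, and the residual itself is added in `ViTLayer.forward`.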
+ """ + + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class ViTAttention(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.attention = ViTSelfAttention(config) + self.output = ViTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ViTIntermediate(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class ViTOutput(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +class ViTLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ViTAttention(config) + self.intermediate = ViTIntermediate(config) + self.output = ViTOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + 
head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class ViTEncoder(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ViTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
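+
+ For illustration, any derived model can be loaded with an explicit attention backend (a minimal sketch; the
+ checkpoint is the one used throughout this file, and SDPA is typically the default when it is available):
+
+ ```python
+ >>> from transformers import ViTModel
+
+ >>> model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k", attn_implementation="eager")
+ ```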
+ """ + + config_class = ViTConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["ViTEmbeddings", "ViTLayer"] + _supports_sdpa = True + _supports_flash_attn_2 = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ViTEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + if module.mask_token is not None: + module.mask_token.data.zero_() + + +VIT_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ViTConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +VIT_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.", + VIT_START_DOCSTRING, +) +class ViTModel(ViTPreTrainedModel): + def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): + super().__init__(config) + self.config = config + + self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = ViTEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = ViTPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ViTPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPooling, + config_class=_CONFIG_FOR_DOC, + modality="vision", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) 
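+ # `ImageProcessor` classes typically return float32 tensors; cast to the dtype of the patch projection
+ # weights so that half-precision (fp16/bf16) models do not fail on a dtype mismatch.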
+ expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class ViTPooler(nn.Module): + def __init__(self, config: ViTConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.pooler_output_size) + self.activation = ACT2FN[config.pooler_act] + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@add_start_docstrings( + """ViT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886). + + + + Note that we provide a script to pre-train this model on custom data in our [examples + directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining). + + + """, + VIT_START_DOCSTRING, +) +class ViTForMaskedImageModeling(ViTPreTrainedModel): + def __init__(self, config: ViTConfig) -> None: + super().__init__(config) + + self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True) + + self.decoder = nn.Sequential( + nn.Conv2d( + in_channels=config.hidden_size, + out_channels=config.encoder_stride**2 * config.num_channels, + kernel_size=1, + ), + nn.PixelShuffle(config.encoder_stride), + ) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, MaskedImageModelingOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). 
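+
+ When `bool_masked_pos` is provided, `config.patch_size` must equal `config.encoder_stride` so that the
+ decoder's `PixelShuffle` reconstructs an image at the input resolution (this is enforced below).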
+ + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") + >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction + >>> list(reconstructed_pixel_values.shape) + [1, 3, 224, 224] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input. " + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}." + ) + + outputs = self.vit( + pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + # Reshape to (batch_size, num_channels, height, width) + sequence_output = sequence_output[:, 1:] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = ( + bool_masked_pos.repeat_interleave(self.config.patch_size, 1) + .repeat_interleave(self.config.patch_size, 2) + .unsqueeze(1) + .contiguous() + ) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of + the [CLS] token) e.g. for ImageNet. + + + + Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by + setting `interpolate_pos_encoding` to `True` in the forward of the model. 
This will interpolate the pre-trained + position embeddings to the higher resolution. + + + """, + VIT_START_DOCSTRING, +) +class ViTForImageClassification(ViTPreTrainedModel): + def __init__(self, config: ViTConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.vit = ViTModel(config, add_pooling_layer=False) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_IMAGE_CLASS_CHECKPOINT, + output_type=ImageClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, + ) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(sequence_output[:, 0, :]) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["ViTForImageClassification", "ViTForMaskedImageModeling", "ViTModel", "ViTPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/vit_mae/__init__.py b/docs/transformers/build/lib/transformers/models/vit_mae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..253017c39d6a2bb0e835d83c23526c8d7b08665a --- /dev/null 
+++ b/docs/transformers/build/lib/transformers/models/vit_mae/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_vit_mae import * + from .modeling_tf_vit_mae import * + from .modeling_vit_mae import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/vit_mae/configuration_vit_mae.py b/docs/transformers/build/lib/transformers/models/vit_mae/configuration_vit_mae.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5ec3600599e6ad3f2708880ce35a9ec9a626e1 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit_mae/configuration_vit_mae.py @@ -0,0 +1,140 @@ +# coding=utf-8 +# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ViT MAE model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class ViTMAEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ViTMAEModel`]. It is used to instantiate an ViT + MAE model according to the specified arguments, defining the model architecture. Instantiating a configuration with + the defaults will yield a similar configuration to that of the ViT + [facebook/vit-mae-base](https://huggingface.co/facebook/vit-mae-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+ decoder_num_attention_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads for each attention layer in the decoder.
+ decoder_hidden_size (`int`, *optional*, defaults to 512):
+ Dimensionality of the decoder.
+ decoder_num_hidden_layers (`int`, *optional*, defaults to 8):
+ Number of hidden layers in the decoder.
+ decoder_intermediate_size (`int`, *optional*, defaults to 2048):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the decoder.
+ mask_ratio (`float`, *optional*, defaults to 0.75):
+ The ratio of the number of masked tokens in the input sequence.
+ norm_pix_loss (`bool`, *optional*, defaults to `False`):
+ Whether or not to train with normalized pixels (see Table 3 in the paper). Using normalized pixels improved
+ representation quality in the authors' experiments.
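+
+ With the default `mask_ratio=0.75` on 224x224 inputs split into 16x16 patches (196 patch tokens), the
+ encoder only sees `int(196 * (1 - 0.75)) = 49` visible patch tokens per image (plus the CLS token).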
+ + Example: + + ```python + >>> from transformers import ViTMAEConfig, ViTMAEModel + + >>> # Initializing a ViT MAE vit-mae-base style configuration + >>> configuration = ViTMAEConfig() + + >>> # Initializing a model (with random weights) from the vit-mae-base style configuration + >>> model = ViTMAEModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vit_mae" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-12, + image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + decoder_num_attention_heads=16, + decoder_hidden_size=512, + decoder_num_hidden_layers=8, + decoder_intermediate_size=2048, + mask_ratio=0.75, + norm_pix_loss=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.decoder_num_attention_heads = decoder_num_attention_heads + self.decoder_hidden_size = decoder_hidden_size + self.decoder_num_hidden_layers = decoder_num_hidden_layers + self.decoder_intermediate_size = decoder_intermediate_size + self.mask_ratio = mask_ratio + self.norm_pix_loss = norm_pix_loss + + +__all__ = ["ViTMAEConfig"] diff --git a/docs/transformers/build/lib/transformers/models/vit_mae/convert_vit_mae_to_pytorch.py b/docs/transformers/build/lib/transformers/models/vit_mae/convert_vit_mae_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..47e77593f6fd3ad7c2b7ff2c329b84f432060c7d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit_mae/convert_vit_mae_to_pytorch.py @@ -0,0 +1,178 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
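+
+# Example invocation (illustrative; the checkpoint URL is the script's default and the output path is arbitrary):
+#
+#   python convert_vit_mae_to_pytorch.py \
+#       --checkpoint_url https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth \
+#       --pytorch_dump_folder_path ./vit-mae-base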
+"""Convert ViT MAE checkpoints from the original repository: https://github.com/facebookresearch/mae""" + +import argparse + +import requests +import torch +from PIL import Image + +from transformers import ViTMAEConfig, ViTMAEForPreTraining, ViTMAEImageProcessor + + +def rename_key(name): + if "cls_token" in name: + name = name.replace("cls_token", "vit.embeddings.cls_token") + if "mask_token" in name: + name = name.replace("mask_token", "decoder.mask_token") + if "decoder_pos_embed" in name: + name = name.replace("decoder_pos_embed", "decoder.decoder_pos_embed") + if "pos_embed" in name and "decoder" not in name: + name = name.replace("pos_embed", "vit.embeddings.position_embeddings") + if "patch_embed.proj" in name: + name = name.replace("patch_embed.proj", "vit.embeddings.patch_embeddings.projection") + if "patch_embed.norm" in name: + name = name.replace("patch_embed.norm", "vit.embeddings.norm") + if "decoder_blocks" in name: + name = name.replace("decoder_blocks", "decoder.decoder_layers") + if "blocks" in name: + name = name.replace("blocks", "vit.encoder.layer") + if "attn.proj" in name: + name = name.replace("attn.proj", "attention.output.dense") + if "attn" in name: + name = name.replace("attn", "attention.self") + if "norm1" in name: + name = name.replace("norm1", "layernorm_before") + if "norm2" in name: + name = name.replace("norm2", "layernorm_after") + if "mlp.fc1" in name: + name = name.replace("mlp.fc1", "intermediate.dense") + if "mlp.fc2" in name: + name = name.replace("mlp.fc2", "output.dense") + if "decoder_embed" in name: + name = name.replace("decoder_embed", "decoder.decoder_embed") + if "decoder_norm" in name: + name = name.replace("decoder_norm", "decoder.decoder_norm") + if "decoder_pred" in name: + name = name.replace("decoder_pred", "decoder.decoder_pred") + if "norm.weight" in name and "decoder" not in name: + name = name.replace("norm.weight", "vit.layernorm.weight") + if "norm.bias" in name and "decoder" not in name: + name = name.replace("norm.bias", "vit.layernorm.bias") + + return name + + +def convert_state_dict(orig_state_dict, config): + for key in orig_state_dict.copy().keys(): + val = orig_state_dict.pop(key) + + if "qkv" in key: + key_split = key.split(".") + layer_num = int(key_split[1]) + if "decoder_blocks" in key: + dim = config.decoder_hidden_size + prefix = "decoder.decoder_layers." + if "weight" in key: + orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :] + orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :] + orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :] + elif "bias" in key: + orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.bias"] = val[:dim] + orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2] + orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.bias"] = val[-dim:] + else: + dim = config.hidden_size + prefix = "vit.encoder.layer." 
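+ # the encoder branch splits the fused qkv tensor exactly like the decoder branch above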
+ if "weight" in key:
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.weight"] = val[:dim, :]
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.weight"] = val[dim : dim * 2, :]
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.weight"] = val[-dim:, :]
+ elif "bias" in key:
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.query.bias"] = val[:dim]
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.key.bias"] = val[dim : dim * 2]
+ orig_state_dict[f"{prefix}{layer_num}.attention.attention.value.bias"] = val[-dim:]
+
+ else:
+ orig_state_dict[rename_key(key)] = val
+
+ return orig_state_dict
+
+
+def convert_vit_mae_checkpoint(checkpoint_url, pytorch_dump_folder_path):
+ config = ViTMAEConfig()
+ if "large" in checkpoint_url:
+ config.hidden_size = 1024
+ config.intermediate_size = 4096
+ config.num_hidden_layers = 24
+ config.num_attention_heads = 16
+ elif "huge" in checkpoint_url:
+ config.patch_size = 14
+ config.hidden_size = 1280
+ config.intermediate_size = 5120
+ config.num_hidden_layers = 32
+ config.num_attention_heads = 16
+
+ model = ViTMAEForPreTraining(config)
+
+ state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["model"]
+
+ new_state_dict = convert_state_dict(state_dict, config)
+
+ model.load_state_dict(new_state_dict)
+ model.eval()
+
+ url = "https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg"
+
+ image = Image.open(requests.get(url, stream=True).raw)
+ image_processor = ViTMAEImageProcessor(size=config.image_size)
+ inputs = image_processor(images=image, return_tensors="pt")
+
+ # forward pass
+ torch.manual_seed(2)
+ outputs = model(**inputs)
+ logits = outputs.logits
+
+ if "large" in checkpoint_url:
+ expected_slice = torch.tensor(
+ [[-0.7309, -0.7128, -1.0169], [-1.0161, -0.9058, -1.1878], [-1.0478, -0.9411, -1.1911]]
+ )
+ elif "huge" in checkpoint_url:
+ expected_slice = torch.tensor(
+ [[-1.1599, -0.9199, -1.2221], [-1.1952, -0.9269, -1.2307], [-1.2143, -0.9337, -1.2262]]
+ )
+ else:
+ expected_slice = torch.tensor(
+ [[-0.9192, -0.8481, -1.1259], [-1.1349, -1.0034, -1.2599], [-1.1757, -1.0429, -1.2726]]
+ )
+
+ # verify logits
+ assert torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4)
+
+ print(f"Saving model to {pytorch_dump_folder_path}")
+ model.save_pretrained(pytorch_dump_folder_path)
+
+ print(f"Saving image processor to {pytorch_dump_folder_path}")
+ image_processor.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Required parameters
+ parser.add_argument(
+ "--checkpoint_url",
+ default="https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth",
+ type=str,
+ help="URL of the checkpoint you'd like to convert.",
+ )
+ parser.add_argument(
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
+ ) + + args = parser.parse_args() + convert_vit_mae_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path) diff --git a/docs/transformers/build/lib/transformers/models/vit_mae/modeling_tf_vit_mae.py b/docs/transformers/build/lib/transformers/models/vit_mae/modeling_tf_vit_mae.py new file mode 100644 index 0000000000000000000000000000000000000000..8879a8665f3ed6cf1009d4b7c2366807e7c51664 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit_mae/modeling_tf_vit_mae.py @@ -0,0 +1,1375 @@ +# coding=utf-8 +# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF 2.0 ViT MAE (masked autoencoder) model.""" + +from __future__ import annotations + +import collections.abc +import math +from copy import deepcopy +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...file_utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...modeling_tf_outputs import TFBaseModelOutput +from ...modeling_tf_utils import ( + TFModelInputType, + TFPreTrainedModel, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import logging +from .configuration_vit_mae import ViTMAEConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "ViTMAEConfig" +_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base" + + +@dataclass +class TFViTMAEModelOutput(ModelOutput): + """ + Class for TFViTMAEModel's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + mask (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Tensor indicating which patches are masked (1) and which are not (0). + ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Tensor containing the original index of the (shuffled) masked patches. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus + the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. 
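+
+ The `ids_restore` indices are consumed by the decoder to unshuffle the mask tokens back into their original
+ patch positions before reconstruction.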
+ """ + + last_hidden_state: Optional[tf.Tensor] = None + mask: Optional[tf.Tensor] = None + ids_restore: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFViTMAEDecoderOutput(ModelOutput): + """ + Class for TFViTMAEDecoder's outputs, with potential hidden states and attentions. + + Args: + logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`): + Pixel reconstruction logits. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus + the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + logits: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +@dataclass +class TFViTMAEForPreTrainingOutput(ModelOutput): + """ + Class for TFViTMAEForPreTraining's outputs, with potential hidden states and attentions. + + Args: + loss (`tf.Tensor` of shape `(1,)`): + Pixel reconstruction loss. + logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`): + Pixel reconstruction logits. + mask (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Tensor indicating which patches are masked (1) and which are not (0). + ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Tensor containing the original index of the (shuffled) masked patches. + hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus + the initial embedding outputs. + attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + loss: tf.Tensor | None = None + logits: Optional[tf.Tensor] = None + mask: Optional[tf.Tensor] = None + ids_restore: Optional[tf.Tensor] = None + hidden_states: Tuple[tf.Tensor] | None = None + attentions: Tuple[tf.Tensor] | None = None + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): + """ + Create 2D sin/cos positional embeddings. + + Args: + embed_dim (`int`): + Embedding dimension. + grid_size (`int`): + The grid height and width. + add_cls_token (`bool`, *optional*, defaults to `False`): + Whether or not to add a classification (CLS) token. 
+
+ Returns:
+ (`tf.Tensor` of shape `(grid_size*grid_size, embed_dim)` or `(1+grid_size*grid_size, embed_dim)`): the
+ position embeddings (with or without classification token)
+ """
+ grid_h = tf.range(grid_size, dtype=tf.float32)
+ grid_w = tf.range(grid_size, dtype=tf.float32)
+ grid = tf.meshgrid(grid_w, grid_h) # here w goes first
+ grid = tf.stack(grid, axis=0)
+
+ grid = tf.reshape(grid, [2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if add_cls_token:
+ pos_embed = tf.concat([tf.zeros((1, embed_dim)), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be even")
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = tf.concat([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ Args:
+ embed_dim: output dimension for each position.
+ pos: a list of positions to be encoded, of size (M,).
+
+ Returns:
+ Position embeddings of shape (M, D).
+ """
+ if embed_dim % 2 != 0:
+ raise ValueError("embed_dim must be even")
+
+ omega = tf.range(embed_dim // 2, dtype="float32")
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = tf.reshape(pos, [-1]) # (M,)
+ out = tf.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ # half of the dimensions get the sine pattern and the other half get the cosine pattern;
+ # the two halves are then concatenated
+ emb_sin = tf.sin(out) # (M, D/2)
+ emb_cos = tf.cos(out) # (M, D/2)
+
+ emb = tf.concat([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+class TFViTMAEEmbeddings(keras.layers.Layer):
+ """
+ Construct the CLS token, position and patch embeddings.
+ """
+
+ def __init__(self, config: ViTMAEConfig, **kwargs):
+ super().__init__(**kwargs)
+
+ self.patch_embeddings = TFViTMAEPatchEmbeddings(config, name="patch_embeddings")
+ self.num_patches = self.patch_embeddings.num_patches
+
+ self.config = config
+
+ def build(self, input_shape=None):
+ self.cls_token = self.add_weight(
+ shape=(1, 1, self.config.hidden_size),
+ initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
+ trainable=True,
+ name="cls_token",
+ )
+ self.position_embeddings = self.add_weight(
+ shape=(1, self.num_patches + 1, self.config.hidden_size),
+ initializer="zeros",
+ trainable=False, # fixed sin-cos embedding
+ name="position_embeddings",
+ )
+ pos_embed = get_2d_sincos_pos_embed(
+ self.position_embeddings.shape[-1],
+ int(self.patch_embeddings.num_patches**0.5),
+ add_cls_token=True,
+ )[None, ...]
+ self.position_embeddings.assign(pos_embed)
+
+ if self.built:
+ return
+ self.built = True
+ if getattr(self, "patch_embeddings", None) is not None:
+ with tf.name_scope(self.patch_embeddings.name):
+ self.patch_embeddings.build(None)
+
+ def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
+ """
+ This method interpolates the pre-trained position encodings so that the model can be used on
+ higher-resolution images.
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + batch_size, seq_len, dim = shape_list(embeddings) + num_patches = seq_len - 1 + + _, num_positions, _ = shape_list(self.position_embeddings) + num_positions -= 1 + + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + patch_pos_embed = tf.image.resize( + images=tf.reshape( + patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + ), + size=(h0, w0), + method="bicubic", + ) + + patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim)) + return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1) + + def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None): + """ + Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random + noise. + + Args: + sequence (`tf.Tensor` of shape `(batch_size, sequence_length, dim)`) + noise (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*) which is + mainly used for testing purposes to control randomness and maintain the reproducibility + """ + batch_size, seq_length, dim = shape_list(sequence) + len_keep = int(seq_length * (1 - self.config.mask_ratio)) + + if noise is None: + noise = tf.random.uniform(shape=(batch_size, seq_length), minval=0.0, maxval=1.0) # noise in [0, 1) + + # sort noise for each sample + ids_shuffle = tf.argsort(noise, axis=1) # ascend: small is keep, large is remove + ids_restore = tf.argsort(ids_shuffle, axis=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + sequence_unmasked = tf.gather( + sequence, + axis=1, + batch_dims=1, + indices=ids_keep, + ) + + # generate the binary mask: 0 is keep, 1 is remove + # this hack is needed because TF's EagerTensors don't support + # assignment + mask_keep = tf.zeros((batch_size, len_keep)) + mask_remove = tf.ones((batch_size, seq_length - len_keep)) + mask = tf.concat([mask_keep, mask_remove], axis=-1) + + # unshuffle to get the binary mask + mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore) + + return sequence_unmasked, mask, ids_restore + + def call( + self, pixel_values: tf.Tensor, noise: Optional[tf.Tensor] = None, interpolate_pos_encoding: bool = False + ) -> tf.Tensor: + batch_size, num_channels, height, width = shape_list(pixel_values) + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + if interpolate_pos_encoding: + position_embeddings = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embeddings = self.position_embeddings + # add position embeddings w/o cls token + embeddings = embeddings + position_embeddings[:, 1:, :] + + # masking: length -> length * config.mask_ratio + embeddings, mask, ids_restore = self.random_masking(embeddings, noise) + + # append cls token + cls_token = self.cls_token + position_embeddings[:, :1, :] + cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1)) + embeddings = tf.concat([cls_tokens, embeddings], axis=1) + + return embeddings, mask, ids_restore + + +class TFViTMAEPatchEmbeddings(keras.layers.Layer): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch 
embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
+ Transformer.
+ """
+
+ def __init__(self, config: ViTMAEConfig, **kwargs):
+ super().__init__(**kwargs)
+ image_size, patch_size = config.image_size, config.patch_size
+ num_channels, hidden_size = config.num_channels, config.hidden_size
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.num_channels = num_channels
+ self.config = config
+
+ self.projection = keras.layers.Conv2D(
+ filters=hidden_size,
+ kernel_size=patch_size,
+ strides=patch_size,
+ padding="valid",
+ data_format="channels_last",
+ kernel_initializer="glorot_uniform", # following torch.nn.Linear
+ bias_initializer="zeros",
+ name="projection",
+ )
+
+ def call(
+ self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False
+ ) -> tf.Tensor:
+ batch_size, num_channels, height, width = shape_list(pixel_values)
+ if tf.executing_eagerly():
+ if num_channels != self.num_channels:
+ raise ValueError(
+ "Make sure that the channel dimension of the pixel values matches the one set in the"
+ " configuration."
+ )
+ if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
+ raise ValueError(
+ f"Input image size ({height}*{width}) doesn't match model"
+ f" ({self.image_size[0]}*{self.image_size[1]})."
+ )
+
+ # When running on CPU, `keras.layers.Conv2D` doesn't support `NCHW` format.
+ # So change the input format from `NCHW` to `NHWC`.
+ # shape = (batch_size, in_height, in_width, in_channels=num_channels)
+ pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
+
+ projection = self.projection(pixel_values)
+
+ # Change the 2D spatial dimensions to a single temporal dimension.
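+ # e.g. for a 224x224 input with 16x16 patches: (batch_size, 14, 14, hidden_size) -> (batch_size, 196, hidden_size)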
+ # shape = (batch_size, num_patches, out_channels=embed_dim) + num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0]) + x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1)) + + return x + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, None, self.num_channels]) + + +# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->ViTMAE +class TFViTMAESelfAttention(keras.layers.Layer): + def __init__(self, config: ViTMAEConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + self.config = config + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + mixed_key_layer = self.key(inputs=hidden_states) + mixed_value_layer = self.value(inputs=hidden_states) + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + key_layer = self.transpose_for_scores(mixed_key_layer, batch_size) + value_layer = self.transpose_for_scores(mixed_value_layer, batch_size) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
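For readers following `transpose_for_scores` and the score computation above, here is a standalone sketch of the same head-splitting and scaled dot-product steps. It is illustrative only: plain `tf.nn.softmax` stands in for the library's `stable_softmax`, and all sizes are made up (ViT-Base-like).

```python
import math
import tensorflow as tf

batch, seq_len, num_heads, head_size = 2, 50, 12, 64
hidden = num_heads * head_size

def split_heads(x):
    # (batch, seq, hidden) -> (batch, heads, seq, head_size)
    x = tf.reshape(x, (batch, -1, num_heads, head_size))
    return tf.transpose(x, perm=[0, 2, 1, 3])

q, k, v = (split_heads(tf.random.normal((batch, seq_len, hidden))) for _ in range(3))

scores = tf.matmul(q, k, transpose_b=True) / math.sqrt(head_size)  # (batch, heads, seq, seq)
probs = tf.nn.softmax(scores, axis=-1)
context = tf.matmul(probs, v)  # (batch, heads, seq, head_size)
# merge the heads back: (batch, seq, hidden)
context = tf.reshape(tf.transpose(context, perm=[0, 2, 1, 3]), (batch, seq_len, hidden))
print(context.shape)  # (2, 50, 768)
```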
+ attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "query", None) is not None: + with tf.name_scope(self.query.name): + self.query.build([None, None, self.config.hidden_size]) + if getattr(self, "key", None) is not None: + with tf.name_scope(self.key.name): + self.key.build([None, None, self.config.hidden_size]) + if getattr(self, "value", None) is not None: + with tf.name_scope(self.value.name): + self.value.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->ViTMAE +class TFViTMAESelfOutput(keras.layers.Layer): + """ + The residual connection is defined in TFViTMAELayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ViTMAEConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->ViTMAE +class TFViTMAEAttention(keras.layers.Layer): + def __init__(self, config: ViTMAEConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFViTMAESelfAttention(config, name="attention") + self.dense_output = TFViTMAESelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "self_attention", None) is not None: + with tf.name_scope(self.self_attention.name): + self.self_attention.build(None) + if getattr(self, "dense_output", None) is not None: + with tf.name_scope(self.dense_output.name): + self.dense_output.build(None) + + +# Copied from 
transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->ViTMAE +class TFViTMAEIntermediate(keras.layers.Layer): + def __init__(self, config: ViTMAEConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->ViTMAE +class TFViTMAEOutput(keras.layers.Layer): + def __init__(self, config: ViTMAEConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.dropout = keras.layers.Dropout(rate=config.hidden_dropout_prob) + self.config = config + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = hidden_states + input_tensor + + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "dense", None) is not None: + with tf.name_scope(self.dense.name): + self.dense.build([None, None, self.config.intermediate_size]) + + +# Copied from transformers.models.vit.modeling_tf_vit.TFViTLayer with ViT->ViTMAE +class TFViTMAELayer(keras.layers.Layer): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTMAEConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFViTMAEAttention(config, name="attention") + self.intermediate = TFViTMAEIntermediate(config, name="intermediate") + self.vit_output = TFViTMAEOutput(config, name="output") + + self.layernorm_before = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_before") + self.layernorm_after = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm_after") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + attention_outputs = self.attention( + # in ViTMAE, layernorm is applied before self-attention + input_tensor=self.layernorm_before(inputs=hidden_states), + head_mask=head_mask, + output_attentions=output_attentions, + training=training, + ) + attention_output = attention_outputs[0] + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViTMAE, layernorm is also applied after self-attention + layer_output = self.layernorm_after(inputs=hidden_states) + + intermediate_output = self.intermediate(hidden_states=layer_output) + + # second residual connection is done here + layer_output = self.vit_output( + hidden_states=intermediate_output, input_tensor=hidden_states, 
training=training + ) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "intermediate", None) is not None: + with tf.name_scope(self.intermediate.name): + self.intermediate.build(None) + if getattr(self, "vit_output", None) is not None: + with tf.name_scope(self.vit_output.name): + self.vit_output.build(None) + if getattr(self, "layernorm_before", None) is not None: + with tf.name_scope(self.layernorm_before.name): + self.layernorm_before.build([None, None, self.config.hidden_size]) + if getattr(self, "layernorm_after", None) is not None: + with tf.name_scope(self.layernorm_after.name): + self.layernorm_after.build([None, None, self.config.hidden_size]) + + +# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->ViTMAE +class TFViTMAEEncoder(keras.layers.Layer): + def __init__(self, config: ViTMAEConfig, **kwargs): + super().__init__(**kwargs) + + self.layer = [TFViTMAELayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + head_mask: tf.Tensor, + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states=hidden_states, + head_mask=head_mask[i], + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + + return TFBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFViTMAEMainLayer(keras.layers.Layer): + config_class = ViTMAEConfig + + def __init__(self, config: ViTMAEConfig, **kwargs): + super().__init__(**kwargs) + + self.config = config + + self.embeddings = TFViTMAEEmbeddings(config, name="embeddings") + self.encoder = TFViTMAEEncoder(config, name="encoder") + self.layernorm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm") + + def get_input_embeddings(self) -> keras.layers.Layer: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + def call( + self, + pixel_values: TFModelInputType | None = None, + noise: Optional[tf.Tensor] = None, + head_mask: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + interpolate_pos_encoding: bool = False, + ) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]: + embedding_output, mask, ids_restore = self.embeddings( + pixel_values=pixel_values, + training=training, + noise=noise, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(inputs=sequence_output) + + if not return_dict: + return (sequence_output, mask, ids_restore) + encoder_outputs[1:] + + return TFViTMAEModelOutput( + last_hidden_state=sequence_output, + mask=mask, + ids_restore=ids_restore, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "embeddings", None) is not None: + with tf.name_scope(self.embeddings.name): + self.embeddings.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + if getattr(self, "layernorm", None) is not None: + with tf.name_scope(self.layernorm.name): + self.layernorm.build([None, None, self.config.hidden_size]) + + +class TFViTMAEPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ViTMAEConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + + +VIT_MAE_START_DOCSTRING = r""" + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. 
Because of this support, when using methods like `model.fit()` things should "just work" for you - just
+    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
+    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
+    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
+    positional argument:
+
+    - a single Tensor with `pixel_values` only and nothing else: `model(pixel_values)`
+    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
+    `model([pixel_values, attention_mask])` or `model([pixel_values, attention_mask, token_type_ids])`
+    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
+    `model({"pixel_values": pixel_values, "token_type_ids": token_type_ids})`
+
+    Note that when creating models and layers with
+    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
+    about any of this, as you can just pass inputs like you would to any other Python function!
+
+
+
+    Args:
+        config ([`ViTMAEConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VIT_MAE_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, and each example must have the shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
+            for details.
+
+        head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+            config will be used instead.
+
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+            used instead.
+
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
+            in eager mode, in graph mode the value will always be set to `True`.
+
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the position encodings at the encoder and decoder.
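To make the two accepted input formats described above concrete, a short usage sketch (not part of the patch; the checkpoint name is the one used in this file's examples, and loading it requires network access):

```python
import tensorflow as tf
from transformers import TFViTMAEModel

model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base")
pixel_values = tf.random.uniform((1, 3, 224, 224))

outputs = model(pixel_values=pixel_values)       # keyword argument
outputs = model({"pixel_values": pixel_values})  # dict in the first positional argument
print(outputs.last_hidden_state.shape)
```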
+""" + + +@add_start_docstrings( + "The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.", + VIT_MAE_START_DOCSTRING, +) +class TFViTMAEModel(TFViTMAEPreTrainedModel): + def __init__(self, config: ViTMAEConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.vit = TFViTMAEMainLayer(config, name="vit") + + def get_input_embeddings(self): + return self.vit.get_input_embeddings() + + @unpack_inputs + @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFViTMAEModelOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: TFModelInputType | None = None, + noise: Optional[tf.Tensor] = None, + head_mask: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + interpolate_pos_encoding: bool = False, + ) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, TFViTMAEModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base") + >>> model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base") + + >>> inputs = image_processor(images=image, return_tensors="tf") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + outputs = self.vit( + pixel_values=pixel_values, + noise=noise, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "vit", None) is not None: + with tf.name_scope(self.vit.name): + self.vit.build(None) + + +class TFViTMAEDecoder(keras.layers.Layer): + def __init__(self, config, num_patches, **kwargs): + super().__init__(**kwargs) + self.decoder_embed = keras.layers.Dense(config.decoder_hidden_size, name="decoder_embed") + + decoder_config = deepcopy(config) + decoder_config.hidden_size = config.decoder_hidden_size + decoder_config.num_hidden_layers = config.decoder_num_hidden_layers + decoder_config.num_attention_heads = config.decoder_num_attention_heads + decoder_config.intermediate_size = config.decoder_intermediate_size + self.decoder_layers = [ + TFViTMAELayer(decoder_config, name=f"decoder_layers.{j}") for j in range(config.decoder_num_hidden_layers) + ] + + self.decoder_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="decoder_norm") + self.decoder_pred = keras.layers.Dense( + config.patch_size**2 * config.num_channels, + kernel_initializer=get_initializer(config.initializer_range), + name="decoder_pred", + ) # encoder to decoder + self.config = config + self.num_patches = num_patches + + def build(self, input_shape=None): + self.mask_token = self.add_weight( + shape=(1, 1, self.config.decoder_hidden_size), + initializer=tf.random_normal_initializer(stddev=self.config.initializer_range), + trainable=True, + name="mask_token", + ) + self.decoder_pos_embed = self.add_weight( + shape=(1, self.num_patches + 1, self.config.decoder_hidden_size), + initializer="zeros", + 
trainable=False,
+            name="decoder_pos_embed",
+        )
+        decoder_pos_embed = get_2d_sincos_pos_embed(
+            self.decoder_pos_embed.shape[-1],
+            int(self.num_patches**0.5),
+            add_cls_token=True,
+        )[None, ...]
+        self.decoder_pos_embed.assign(decoder_pos_embed)
+
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "decoder_embed", None) is not None:
+            with tf.name_scope(self.decoder_embed.name):
+                self.decoder_embed.build([None, None, self.config.hidden_size])
+        if getattr(self, "decoder_norm", None) is not None:
+            with tf.name_scope(self.decoder_norm.name):
+                self.decoder_norm.build([None, None, self.config.decoder_hidden_size])
+        if getattr(self, "decoder_pred", None) is not None:
+            with tf.name_scope(self.decoder_pred.name):
+                self.decoder_pred.build([None, None, self.config.decoder_hidden_size])
+        if getattr(self, "decoder_layers", None) is not None:
+            for layer in self.decoder_layers:
+                with tf.name_scope(layer.name):
+                    layer.build(None)
+
+    def interpolate_pos_encoding(self, embeddings) -> tf.Tensor:
+        """
+        This method is a modified version of the interpolation function for the ViT MAE model at the decoder. It
+        allows interpolating the pre-trained decoder position encodings so that the model can be used on higher
+        resolution images.
+
+        Source:
+        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+        """
+
+        # [batch_size, num_patches + 1, hidden_size]
+        _, num_positions, dim = shape_list(self.decoder_pos_embed)
+
+        # -1 removes the class dimension since we later append it without interpolation
+        seq_len = shape_list(embeddings)[1] - 1
+        num_positions = num_positions - 1
+
+        # Separation of class token and patch tokens
+        class_pos_embed = self.decoder_pos_embed[:, :1, :]
+        patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
+
+        # interpolate the position embeddings
+        patch_pos_embed = tf.image.resize(
+            images=tf.reshape(patch_pos_embed, shape=(1, 1, -1, dim)),
+            size=(1, seq_len),
+            method="bicubic",
+        )
+
+        # [1, seq_len, hidden_size]
+        patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
+        # Adding the class token back
+        return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
+
+    def call(
+        self,
+        hidden_states,
+        ids_restore,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        interpolate_pos_encoding=False,
+    ):
+        # embed tokens
+        x = self.decoder_embed(hidden_states)
+        # append mask tokens to sequence
+        mask_tokens = tf.tile(
+            self.mask_token,
+            (shape_list(x)[0], shape_list(ids_restore)[1] + 1 - shape_list(x)[1], 1),
+        )
+        x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1)  # no cls token
+        x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore)  # unshuffle
+        x = tf.concat([x[:, :1, :], x_], axis=1)  # append cls token
+        if interpolate_pos_encoding:
+            decoder_pos_embed = self.interpolate_pos_encoding(x)
+        else:
+            decoder_pos_embed = self.decoder_pos_embed
+        # add pos embed
+        hidden_states = x + decoder_pos_embed
+        # apply Transformer layers (blocks)
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        for i, layer_module in enumerate(self.decoder_layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_outputs = layer_module(
+                hidden_states,
+                head_mask=None,
+                output_attentions=output_attentions,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+        if 
output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.decoder_norm(hidden_states) + + # predictor projection + logits = self.decoder_pred(hidden_states) + + # remove cls token + logits = logits[:, 1:, :] + + if not return_dict: + return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None) + return TFViTMAEDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions) + + +@add_start_docstrings( + "The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.", + VIT_MAE_START_DOCSTRING, +) +class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.vit = TFViTMAEMainLayer(config, name="vit") + self.decoder = TFViTMAEDecoder( + config, + num_patches=self.vit.embeddings.num_patches, + name="decoder", + ) + + def get_input_embeddings(self): + return self.vit.get_input_embeddings() + + def _prune_heads(self, heads_to_prune): + raise NotImplementedError + + def patchify(self, pixel_values, interpolate_pos_encoding: bool = False): + """ + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`): + Pixel values. + interpolate_pos_encoding (`bool`, default `False`): + interpolation flag passed during the forward pass. + + Returns: + `tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Patchified pixel values. + """ + patch_size, num_channels = self.config.patch_size, self.config.num_channels + # make sure channels are last + if shape_list(pixel_values)[1] == num_channels: + pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1)) + + # sanity checks + if not interpolate_pos_encoding: + tf.debugging.assert_equal( + shape_list(pixel_values)[1], + shape_list(pixel_values)[2], + message="Make sure the pixel values have a squared size", + ) + tf.debugging.assert_equal( + shape_list(pixel_values)[1] % patch_size, + 0, + message="Make sure the pixel values have a size that is divisible by the patch size", + ) + tf.debugging.assert_equal( + shape_list(pixel_values)[3], + num_channels, + message=( + "Make sure the number of channels of the pixel values is equal to the one set in the configuration" + ), + ) + + # patchify + batch_size = shape_list(pixel_values)[0] + num_patches_h = shape_list(pixel_values)[1] // patch_size + num_patches_w = shape_list(pixel_values)[2] // patch_size + patchified_pixel_values = tf.reshape( + pixel_values, + (batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels), + ) + patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values) + patchified_pixel_values = tf.reshape( + patchified_pixel_values, + (batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels), + ) + return patchified_pixel_values + + def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None): + """ + Args: + patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Patchified pixel values. + original_image_size (`Tuple[int, int]`, *optional*): + Original image size. + + Returns: + `tf.Tensor` of shape `(batch_size, height, width, num_channels)`: + Pixel values. 
+ """ + patch_size, num_channels = self.config.patch_size, self.config.num_channels + original_image_size = ( + original_image_size + if original_image_size is not None + else (self.config.image_size, self.config.image_size) + ) + original_height, original_width = original_image_size + num_patches_h = original_height // patch_size + num_patches_w = original_width // patch_size + # sanity check + tf.debugging.assert_equal( + num_patches_h * num_patches_w, + shape_list(patchified_pixel_values)[1], + message=f"The number of patches in the patchified pixel values is {shape_list(patchified_pixel_values)[1]} does not match the patches of original image {num_patches_w}*{num_patches_h}", + ) + + # unpatchify + batch_size = shape_list(patchified_pixel_values)[0] + patchified_pixel_values = tf.reshape( + patchified_pixel_values, + (batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels), + ) + patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values) + pixel_values = tf.reshape( + patchified_pixel_values, + (batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels), + ) + return pixel_values + + def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False): + """ + Args: + pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`): + Pixel values. + pred (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Predicted pixel values. + mask (`tf.Tensor` of shape `(batch_size, sequence_length)`): + Tensor indicating which patches are masked (1) and which are not (0). + interpolate_pos_encoding (`bool`, *optional*, default `False`): + interpolation flag passed during the forward pass. + + Returns: + `tf.Tensor`: Pixel reconstruction loss. 
+ """ + target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + if self.config.norm_pix_loss: + mean = tf.reduce_mean(target, axis=-1, keepdims=True) + var = tf.math.reduce_variance(target, axis=-1, keepdims=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = tf.reduce_mean(loss, axis=-1) # [batch_size, num_patches], mean loss per patch + + loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) # mean loss on removed patches + loss = tf.reshape(loss, (1,)) + return loss + + @unpack_inputs + @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + pixel_values: TFModelInputType | None = None, + noise: Optional[tf.Tensor] = None, + head_mask: np.ndarray | tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + interpolate_pos_encoding: bool = False, + ) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, TFViTMAEForPreTraining + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base") + >>> model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> loss = outputs.loss + >>> mask = outputs.mask + >>> ids_restore = outputs.ids_restore + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vit( + pixel_values=pixel_values, + noise=noise, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + latent = outputs.last_hidden_state + ids_restore = outputs.ids_restore + mask = outputs.mask + + # [batch_size, num_patches, patch_size**2*3] + decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding) + logits = decoder_outputs.logits + + loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding) + + if not return_dict: + output = (logits, mask, ids_restore) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFViTMAEForPreTrainingOutput( + loss=loss, + logits=logits, + mask=mask, + ids_restore=ids_restore, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "vit", None) is not None: + with tf.name_scope(self.vit.name): + self.vit.build(None) + if getattr(self, "decoder", None) is not None: + with tf.name_scope(self.decoder.name): + self.decoder.build(None) + + +__all__ = ["TFViTMAEForPreTraining", "TFViTMAEModel", "TFViTMAEPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/vit_mae/modeling_vit_mae.py b/docs/transformers/build/lib/transformers/models/vit_mae/modeling_vit_mae.py new file mode 100644 index 
0000000000000000000000000000000000000000..4636519ee67e519cb77d08f4897956e856c19b80 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit_mae/modeling_vit_mae.py @@ -0,0 +1,1163 @@ +# coding=utf-8 +# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ViT MAE (masked autoencoder) model.""" + +import collections.abc +from copy import deepcopy +from dataclasses import dataclass +from typing import Callable, Optional, Set, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_vit_mae import ViTMAEConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "ViTMAEConfig" +_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base" + + +@dataclass +class ViTMAEModelOutput(ModelOutput): + """ + Class for ViTMAEModel's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Tensor indicating which patches are masked (1) and which are not (0). + ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Tensor containing the original index of the (shuffled) masked patches. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + mask: Optional[torch.LongTensor] = None + ids_restore: Optional[torch.LongTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class ViTMAEDecoderOutput(ModelOutput): + """ + Class for ViTMAEDecoder's outputs, with potential hidden states and attentions. 
+ + Args: + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`): + Pixel reconstruction logits. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + logits: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class ViTMAEForPreTrainingOutput(ModelOutput): + """ + Class for ViTMAEForPreTraining's outputs, with potential hidden states and attentions. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`): + Pixel reconstruction loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`): + Pixel reconstruction logits. + mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Tensor indicating which patches are masked (1) and which are not (0). + ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Tensor containing the original index of the (shuffled) masked patches. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer + plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + mask: Optional[torch.LongTensor] = None + ids_restore: Optional[torch.LongTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False): + """ + Create 2D sin/cos positional embeddings. + + Args: + embed_dim (`int`): + Embedding dimension. + grid_size (`int`): + The grid height and width. + add_cls_token (`bool`, *optional*, defaults to `False`): + Whether or not to add a classification (CLS) token. 
+ + Returns: + (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the + position embeddings (with or without classification token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if add_cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be even") + + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class ViTMAEEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. + + """ + + def __init__(self, config): + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.patch_embeddings = ViTMAEPatchEmbeddings(config) + self.num_patches = self.patch_embeddings.num_patches + # fixed sin-cos embedding + self.position_embeddings = nn.Parameter( + torch.zeros(1, self.num_patches + 1, config.hidden_size), requires_grad=False + ) + self.patch_size = config.patch_size + self.config = config + + def initialize_weights(self): + # initialize (and freeze) position embeddings by sin-cos embedding + pos_embed = get_2d_sincos_pos_embed( + self.position_embeddings.shape[-1], int(self.patch_embeddings.num_patches**0.5), add_cls_token=True + ) + self.position_embeddings.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d) + w = self.patch_embeddings.projection.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range) + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
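Stepping back to the sin-cos helpers just defined: the 1D building block reduces to an outer product of positions and frequencies followed by `sin`/`cos`. A toy sketch mirroring `get_1d_sincos_pos_embed_from_grid` (the 4-position, 8-dimension sizes are illustrative):

```python
import numpy as np

# 4 positions encoded into 8 dims (4 sine + 4 cosine channels)
embed_dim, positions = 8, np.arange(4, dtype=float)
omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=float) / (embed_dim / 2))
out = np.einsum("m,d->md", positions, omega)  # (4, 4) outer product
emb = np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (4, 8)
print(emb.shape)
```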
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def random_masking(self, sequence, noise=None): + """ + Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random + noise. + + Args: + sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`) + noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is + mainly used for testing purposes to control randomness and maintain the reproducibility + """ + batch_size, seq_length, dim = sequence.shape + len_keep = int(seq_length * (1 - self.config.mask_ratio)) + + if noise is None: + noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([batch_size, seq_length], device=sequence.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return sequence_unmasked, mask, ids_restore + + def forward(self, pixel_values, noise=None, interpolate_pos_encoding: bool = False): + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + if interpolate_pos_encoding: + position_embeddings = self.interpolate_pos_encoding(embeddings, height, width) + else: + position_embeddings = self.position_embeddings + + # add position embeddings w/o cls token + embeddings = embeddings + position_embeddings[:, 1:, :] + + # masking: length -> length * config.mask_ratio + embeddings, mask, ids_restore = self.random_masking(embeddings, noise) + + # append cls token + cls_token = self.cls_token + position_embeddings[:, :1, :] + cls_tokens = cls_token.expand(embeddings.shape[0], -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + return embeddings, mask, ids_restore + + +class 
ViTMAEPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values, interpolate_pos_encoding: bool = False): + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + + if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]): + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + x = self.projection(pixel_values).flatten(2).transpose(1, 2) + return x + + +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention ViT->ViTMAE +class ViTMAESelfAttention(nn.Module): + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." 
+ ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMAE +class ViTMAESelfOutput(nn.Module): + """ + The residual connection is defined in ViTMAELayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
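The `config._attn_implementation` branch above dispatches to the kernels registered in `ALL_ATTENTION_FUNCTIONS`; `"eager"` reproduces `eager_attention_forward`. In recent `transformers` releases the backend can be selected at load time. A hedged usage sketch (the expected output shape assumes the default 0.75 mask ratio on a 224x224 input, i.e. 49 kept patches plus the CLS token):

```python
import torch
from transformers import ViTMAEModel

# "sdpa" routes attention through torch.nn.functional.scaled_dot_product_attention
model = ViTMAEModel.from_pretrained("facebook/vit-mae-base", attn_implementation="sdpa")

pixel_values = torch.rand(1, 3, 224, 224)
outputs = model(pixel_values)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 50, 768])
```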
+ """ + + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMAE +class ViTMAEAttention(nn.Module): + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__() + self.attention = ViTMAESelfAttention(config) + self.output = ViTMAESelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->ViTMAE +class ViTMAEIntermediate(nn.Module): + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->ViTMAE +class ViTMAEOutput(nn.Module): + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMAE,VIT->VITMAE +class ViTMAELayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + 
self.seq_len_dim = 1 + self.attention = ViTMAEAttention(config) + self.intermediate = ViTMAEIntermediate(config) + self.output = ViTMAEOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViTMAE, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViTMAE, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMAE +class ViTMAEEncoder(nn.Module): + def __init__(self, config: ViTMAEConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ViTMAELayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ViTMAEPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = ViTMAEConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _supports_sdpa = True + _supports_flash_attn_2 = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ViTMAEEmbeddings): + module.initialize_weights() + elif isinstance(module, ViTMAEDecoder): + module.mask_token.data.zero_() + module.decoder_pos_embed.data.zero_() + + +VIT_MAE_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ViTMAEConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +VIT_MAE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + interpolate_pos_encoding (`bool`, *optional*, default `False`): + Whether to interpolate the pre-trained position encodings. This is mainly used to use the model on higher + resolution images. +""" + + +@add_start_docstrings( + "The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.", + VIT_MAE_START_DOCSTRING, +) +class ViTMAEModel(ViTMAEPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.config = config + + self.embeddings = ViTMAEEmbeddings(config) + self.encoder = ViTMAEEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ViTMAEModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + noise: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Union[Tuple, ViTMAEModelOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, ViTMAEModel + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base") + >>> model = ViTMAEModel.from_pretrained("facebook/vit-mae-base") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output, mask, ids_restore = self.embeddings( + pixel_values, noise=noise, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + if not return_dict: + return (sequence_output, mask, ids_restore) + encoder_outputs[1:] + + return ViTMAEModelOutput( + last_hidden_state=sequence_output, + mask=mask, + ids_restore=ids_restore, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class ViTMAEDecoder(nn.Module): + def __init__(self, config, num_patches): + super().__init__() + self.decoder_embed = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size)) + self.decoder_pos_embed = nn.Parameter( + torch.zeros(1, num_patches + 1, config.decoder_hidden_size), requires_grad=False + ) # fixed sin-cos embedding + + decoder_config = deepcopy(config) + decoder_config.hidden_size = config.decoder_hidden_size + decoder_config.num_hidden_layers = config.decoder_num_hidden_layers + decoder_config.num_attention_heads = config.decoder_num_attention_heads + 
decoder_config.intermediate_size = config.decoder_intermediate_size
+ self.decoder_layers = nn.ModuleList(
+ [ViTMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
+ )
+
+ self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
+ self.decoder_pred = nn.Linear(
+ config.decoder_hidden_size, config.patch_size**2 * config.num_channels, bias=True
+ ) # decoder to patch pixels
+ self.gradient_checkpointing = False
+ self.config = config
+ self.initialize_weights(num_patches)
+
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
+ """
+ Decoder-side variant of the position-encoding interpolation used in ViT-MAE: it interpolates the
+ pre-trained decoder position encodings so that the model can be used on higher-resolution images.
+
+ Adapted from:
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
+ """
+
+ # -1 removes the class dimension since we later append it without interpolation
+ embeddings_positions = embeddings.shape[1] - 1
+
+ # Separation of class token and patch tokens
+ class_pos_embed = self.decoder_pos_embed[:, :1]
+ patch_pos_embed = self.decoder_pos_embed[:, 1:]
+
+ # To retain the final 3d tensor with the required dimensions
+ dim = self.decoder_pos_embed.shape[-1]
+
+ # Increasing a dimension to enable bicubic interpolation
+ patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim)
+
+ # permute to bring the dimension to be interpolated, to the last
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
+
+ # Interpolate the decoder position embeddings to the number of patch positions in `embeddings`,
+ # keeping the second-to-last dimension constant
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed,
+ size=(patch_pos_embed.shape[-2], embeddings_positions),
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ # Converting back to the original shape
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ # Adding the class token back
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
+
+ def initialize_weights(self, num_patches):
+ # initialize (and freeze) position embeddings by sin-cos embedding
+ decoder_pos_embed = get_2d_sincos_pos_embed(
+ self.decoder_pos_embed.shape[-1], int(num_patches**0.5), add_cls_token=True
+ )
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
+
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
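+ # the mask token is a learned embedding that stands in for the patches dropped by the encoder;
+ # pretrained checkpoints overwrite this initialization when loaded via from_pretrained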
+ torch.nn.init.normal_(self.mask_token, std=self.config.initializer_range)
+
+ def forward(
+ self,
+ hidden_states,
+ ids_restore,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ interpolate_pos_encoding: bool = False,
+ ):
+ # embed tokens
+ x = self.decoder_embed(hidden_states)
+
+ # append mask tokens to sequence
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
+ # unshuffle
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
+ # add pos embed
+ if interpolate_pos_encoding:
+ decoder_pos_embed = self.interpolate_pos_encoding(x)
+ else:
+ decoder_pos_embed = self.decoder_pos_embed
+ hidden_states = x + decoder_pos_embed
+
+ # apply Transformer layers (blocks)
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ for i, layer_module in enumerate(self.decoder_layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer_module.__call__,
+ hidden_states,
+ None,
+ output_attentions,
+ )
+ else:
+ layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = self.decoder_norm(hidden_states)
+
+ # predictor projection
+ logits = self.decoder_pred(hidden_states)
+
+ # remove cls token
+ logits = logits[:, 1:, :]
+
+ if not return_dict:
+ return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
+ return ViTMAEDecoderOutput(
+ logits=logits,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+@add_start_docstrings(
+ """The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.
+
+ <Tip>
+
+ Note that we provide a script to pre-train this model on custom data in our [examples
+ directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
+
+ </Tip>
+
+ """,
+ VIT_MAE_START_DOCSTRING,
+)
+class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.vit = ViTMAEModel(config)
+ self.decoder = ViTMAEDecoder(config, num_patches=self.vit.embeddings.num_patches)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.vit.embeddings.patch_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.vit.encoder.layer[layer].attention.prune_heads(heads)
+
+ def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
+ """
+ Args:
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Pixel values.
+ interpolate_pos_encoding (`bool`, *optional*, default `False`):
+ interpolation flag passed during the forward pass.
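+ For example, a `(batch_size, 3, 224, 224)` input with `patch_size=16` is patchified to
+ `(batch_size, 196, 768)`, since `(224 / 16) ** 2 = 196` patches of `16 * 16 * 3 = 768` values each.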
+ + Returns: + `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Patchified pixel values. + """ + patch_size, num_channels = self.config.patch_size, self.config.num_channels + # sanity checks + if not interpolate_pos_encoding and ( + pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0 + ): + raise ValueError("Make sure the pixel values have a squared size that is divisible by the patch size") + if pixel_values.shape[1] != num_channels: + raise ValueError( + "Make sure the number of channels of the pixel values is equal to the one set in the configuration" + ) + + # patchify + batch_size = pixel_values.shape[0] + num_patches_h = pixel_values.shape[2] // patch_size + num_patches_w = pixel_values.shape[3] // patch_size + patchified_pixel_values = pixel_values.reshape( + batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size + ) + patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values) + patchified_pixel_values = patchified_pixel_values.reshape( + batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels + ) + return patchified_pixel_values + + def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None): + """ + Args: + patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Patchified pixel values. + original_image_size (`Tuple[int, int]`, *optional*): + Original image size. + + Returns: + `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`: + Pixel values. + """ + patch_size, num_channels = self.config.patch_size, self.config.num_channels + original_image_size = ( + original_image_size + if original_image_size is not None + else (self.config.image_size, self.config.image_size) + ) + original_height, original_width = original_image_size + num_patches_h = original_height // patch_size + num_patches_w = original_width // patch_size + # sanity check + if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]: + raise ValueError( + f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}" + ) + + # unpatchify + batch_size = patchified_pixel_values.shape[0] + patchified_pixel_values = patchified_pixel_values.reshape( + batch_size, + num_patches_h, + num_patches_w, + patch_size, + patch_size, + num_channels, + ) + patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values) + pixel_values = patchified_pixel_values.reshape( + batch_size, + num_channels, + num_patches_h * patch_size, + num_patches_w * patch_size, + ) + return pixel_values + + def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False): + """ + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. + pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`: + Predicted pixel values. + mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Tensor indicating which patches are masked (1) and which are not (0). + interpolate_pos_encoding (`bool`, *optional*, default `False`): + interpolation flag passed during the forward pass. + + Returns: + `torch.FloatTensor`: Pixel reconstruction loss. 
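+ When `config.norm_pix_loss` is `True`, each target patch is first normalized to zero mean and unit
+ variance, `target = (target - mean) / sqrt(var + 1e-6)`, before the squared error is computed; the
+ loss is then averaged per patch and over the masked patches only, `loss = (loss * mask).sum() / mask.sum()`.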
+ """ + target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + if self.config.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + @add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + noise: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + ) -> Union[Tuple, ViTMAEForPreTrainingOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, ViTMAEForPreTraining + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-base") + >>> model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> loss = outputs.loss + >>> mask = outputs.mask + >>> ids_restore = outputs.ids_restore + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vit( + pixel_values, + noise=noise, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + latent = outputs.last_hidden_state + ids_restore = outputs.ids_restore + mask = outputs.mask + + decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding) + logits = decoder_outputs.logits # shape (batch_size, num_patches, patch_size*patch_size*num_channels) + + loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding) + + if not return_dict: + output = (logits, mask, ids_restore) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return ViTMAEForPreTrainingOutput( + loss=loss, + logits=logits, + mask=mask, + ids_restore=ids_restore, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["ViTMAEForPreTraining", "ViTMAELayer", "ViTMAEModel", "ViTMAEPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/vit_msn/__init__.py b/docs/transformers/build/lib/transformers/models/vit_msn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..547c5f4c04b912fe69a09470106c0d523df63931 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit_msn/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_vit_msn import *
+ from .modeling_vit_msn import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/vit_msn/configuration_vit_msn.py b/docs/transformers/build/lib/transformers/models/vit_msn/configuration_vit_msn.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd47df3e9932e0a7e4fa1a9084976214fa9d2939
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/vit_msn/configuration_vit_msn.py
@@ -0,0 +1,115 @@
+# coding=utf-8
+# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""ViT MSN model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class ViTMSNConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`ViTMSNModel`]. It is used to instantiate a ViT
+ MSN model according to the specified arguments, defining the model architecture. Instantiating a configuration with
+ the defaults will yield a similar configuration to that of the ViT MSN
+ [facebook/vit_msn_base](https://huggingface.co/facebook/vit_msn_base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
+ The epsilon used by the layer normalization layers.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ patch_size (`int`, *optional*, defaults to 16):
+ The size (resolution) of each patch.
+ num_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ qkv_bias (`bool`, *optional*, defaults to `True`):
+ Whether to add a bias to the queries, keys and values.
+
+ Example:
+
+ ```python
+ >>> from transformers import ViTMSNModel, ViTMSNConfig
+
+ >>> # Initializing a ViT MSN vit-msn-base style configuration
+ >>> configuration = ViTMSNConfig()
+
+ >>> # Initializing a model from the vit-msn-base style configuration
+ >>> model = ViTMSNModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "vit_msn"
+
+ def __init__(
+ self,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.0,
+ attention_probs_dropout_prob=0.0,
+ initializer_range=0.02,
+ layer_norm_eps=1e-06,
+ image_size=224,
+ patch_size=16,
+ num_channels=3,
+ qkv_bias=True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.image_size = image_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.qkv_bias = qkv_bias
+
+
+__all__ = ["ViTMSNConfig"]
diff --git a/docs/transformers/build/lib/transformers/models/vit_msn/convert_msn_to_pytorch.py b/docs/transformers/build/lib/transformers/models/vit_msn/convert_msn_to_pytorch.py
new file mode 100644
index 0000000000000000000000000000000000000000..899c74f183205e9fdc18984a1f15e877bc64fe31
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/vit_msn/convert_msn_to_pytorch.py
@@ -0,0 +1,245 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
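+
+# Example invocation (the output folder name below is illustrative):
+#   python convert_msn_to_pytorch.py \
+#       --checkpoint_url https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar \
+#       --pytorch_dump_folder_path ./vit-msn-small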
+"""Convert ViT MSN checkpoints from the original repository: https://github.com/facebookresearch/msn""" + +import argparse +import json + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import ViTImageProcessor, ViTMSNConfig, ViTMSNModel +from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +torch.set_grad_enabled(False) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config, base_model=False): + rename_keys = [] + for i in range(config.num_hidden_layers): + # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms + rename_keys.append((f"module.blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight")) + rename_keys.append((f"module.blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias")) + rename_keys.append( + (f"module.blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight") + ) + rename_keys.append((f"module.blocks.{i}.attn.proj.bias", f"vit.encoder.layer.{i}.attention.output.dense.bias")) + rename_keys.append((f"module.blocks.{i}.norm2.weight", f"vit.encoder.layer.{i}.layernorm_after.weight")) + rename_keys.append((f"module.blocks.{i}.norm2.bias", f"vit.encoder.layer.{i}.layernorm_after.bias")) + rename_keys.append((f"module.blocks.{i}.mlp.fc1.weight", f"vit.encoder.layer.{i}.intermediate.dense.weight")) + rename_keys.append((f"module.blocks.{i}.mlp.fc1.bias", f"vit.encoder.layer.{i}.intermediate.dense.bias")) + rename_keys.append((f"module.blocks.{i}.mlp.fc2.weight", f"vit.encoder.layer.{i}.output.dense.weight")) + rename_keys.append((f"module.blocks.{i}.mlp.fc2.bias", f"vit.encoder.layer.{i}.output.dense.bias")) + + # projection layer + position embeddings + rename_keys.extend( + [ + ("module.cls_token", "vit.embeddings.cls_token"), + ("module.patch_embed.proj.weight", "vit.embeddings.patch_embeddings.projection.weight"), + ("module.patch_embed.proj.bias", "vit.embeddings.patch_embeddings.projection.bias"), + ("module.pos_embed", "vit.embeddings.position_embeddings"), + ] + ) + + if base_model: + # layernorm + pooler + rename_keys.extend( + [ + ("module.norm.weight", "layernorm.weight"), + ("module.norm.bias", "layernorm.bias"), + ] + ) + + # if just the base model, we should remove "vit" from all keys that start with "vit" + rename_keys = [(pair[0], pair[1][4:]) if pair[1].startswith("vit") else pair for pair in rename_keys] + else: + # layernorm + classification head + rename_keys.extend( + [ + ("norm.weight", "vit.layernorm.weight"), + ("norm.bias", "vit.layernorm.bias"), + ("head.weight", "classifier.weight"), + ("head.bias", "classifier.bias"), + ] + ) + + return rename_keys + + +# we split up the matrix of each encoder layer into queries, keys and values +def read_in_q_k_v(state_dict, config, base_model=False): + for i in range(config.num_hidden_layers): + if base_model: + prefix = "" + else: + prefix = "vit." 
+ # read in weights + bias of input projection layer (in timm, this is a single matrix + bias) + in_proj_weight = state_dict.pop(f"module.blocks.{i}.attn.qkv.weight") + in_proj_bias = state_dict.pop(f"module.blocks.{i}.attn.qkv.bias") + # next, add query, keys and values (in that order) to the state dict + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.weight"] = in_proj_weight[ + : config.hidden_size, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.query.bias"] = in_proj_bias[: config.hidden_size] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.weight"] = in_proj_weight[ + config.hidden_size : config.hidden_size * 2, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.key.bias"] = in_proj_bias[ + config.hidden_size : config.hidden_size * 2 + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.weight"] = in_proj_weight[ + -config.hidden_size :, : + ] + state_dict[f"{prefix}encoder.layer.{i}.attention.attention.value.bias"] = in_proj_bias[-config.hidden_size :] + + +def remove_classification_head_(state_dict): + ignore_keys = ["head.weight", "head.bias"] + for k in ignore_keys: + state_dict.pop(k, None) + + +def remove_projection_head(state_dict): + # projection head is used in the self-supervised pre-training in MSN, + # for downstream task it's not needed. + ignore_keys = [ + "module.fc.fc1.weight", + "module.fc.fc1.bias", + "module.fc.bn1.weight", + "module.fc.bn1.bias", + "module.fc.bn1.running_mean", + "module.fc.bn1.running_var", + "module.fc.bn1.num_batches_tracked", + "module.fc.fc2.weight", + "module.fc.fc2.bias", + "module.fc.bn2.weight", + "module.fc.bn2.bias", + "module.fc.bn2.running_mean", + "module.fc.bn2.running_var", + "module.fc.bn2.num_batches_tracked", + "module.fc.fc3.weight", + "module.fc.fc3.bias", + ] + for k in ignore_keys: + state_dict.pop(k, None) + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +def convert_vit_msn_checkpoint(checkpoint_url, pytorch_dump_folder_path): + config = ViTMSNConfig() + config.num_labels = 1000 + + repo_id = "datasets/huggingface/label-files" + filename = "imagenet-1k-id2label.json" + id2label = json.load(open(hf_hub_download(repo_id, filename), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + + if "s16" in checkpoint_url: + config.hidden_size = 384 + config.intermediate_size = 1536 + config.num_attention_heads = 6 + elif "l16" in checkpoint_url: + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + config.hidden_dropout_prob = 0.1 + elif "b4" in checkpoint_url: + config.patch_size = 4 + elif "l7" in checkpoint_url: + config.patch_size = 7 + config.hidden_size = 1024 + config.intermediate_size = 4096 + config.num_hidden_layers = 24 + config.num_attention_heads = 16 + config.hidden_dropout_prob = 0.1 + + model = ViTMSNModel(config) + + state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")["target_encoder"] + + image_processor = ViTImageProcessor(size=config.image_size) + + remove_projection_head(state_dict) + rename_keys = create_rename_keys(config, base_model=True) + + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + read_in_q_k_v(state_dict, config, base_model=True) + + model.load_state_dict(state_dict) + model.eval() + + url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + image = 
Image.open(requests.get(url, stream=True).raw) + image_processor = ViTImageProcessor( + size=config.image_size, image_mean=IMAGENET_DEFAULT_MEAN, image_std=IMAGENET_DEFAULT_STD + ) + inputs = image_processor(images=image, return_tensors="pt") + + # forward pass + torch.manual_seed(2) + outputs = model(**inputs) + last_hidden_state = outputs.last_hidden_state + + # The following Colab Notebook was used to generate these outputs: + # https://colab.research.google.com/gist/sayakpaul/3672419a04f5997827503fd84079bdd1/scratchpad.ipynb + if "s16" in checkpoint_url: + expected_slice = torch.tensor([[-1.0915, -1.4876, -1.1809]]) + elif "b16" in checkpoint_url: + expected_slice = torch.tensor([[14.2889, -18.9045, 11.7281]]) + elif "l16" in checkpoint_url: + expected_slice = torch.tensor([[41.5028, -22.8681, 45.6475]]) + elif "b4" in checkpoint_url: + expected_slice = torch.tensor([[-4.3868, 5.2932, -0.4137]]) + else: + expected_slice = torch.tensor([[-0.1792, -0.6465, 2.4263]]) + + # verify logits + assert torch.allclose(last_hidden_state[:, 0, :3], expected_slice, atol=1e-4) + + print(f"Saving model to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + + print(f"Saving image processor to {pytorch_dump_folder_path}") + image_processor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--checkpoint_url", + default="https://dl.fbaipublicfiles.com/msn/vits16_800ep.pth.tar", + type=str, + help="URL of the checkpoint you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + + args = parser.parse_args() + convert_vit_msn_checkpoint(args.checkpoint_url, args.pytorch_dump_folder_path) diff --git a/docs/transformers/build/lib/transformers/models/vit_msn/modeling_vit_msn.py b/docs/transformers/build/lib/transformers/models/vit_msn/modeling_vit_msn.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5a3d56ba607f2319ca2132e2f6b6ad9341bafd --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vit_msn/modeling_vit_msn.py @@ -0,0 +1,741 @@ +# coding=utf-8 +# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch ViT MSN (masked siamese network) model.""" + +import collections.abc +from typing import Callable, Dict, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, + torch_int, +) +from .configuration_vit_msn import ViTMSNConfig + + +logger = logging.get_logger(__name__) + + +_CONFIG_FOR_DOC = "ViTMSNConfig" +_CHECKPOINT_FOR_DOC = "facebook/vit-msn-small" + + +class ViTMSNEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + """ + + def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = ViTMSNPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.patch_size + self.config = config + + # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. 
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.ViTPatchEmbeddings with ViT->ViTMSN +class ViTMSNPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. 
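+ For example, with `image_size=224`, `patch_size=16` and `hidden_size=768`, a `(batch_size, 3, 224, 224)`
+ input is projected by a stride-16 convolution to `(batch_size, 196, 768)`.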
+ """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->ViTMSN +class ViTMSNSelfAttention(nn.Module): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." 
+ ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTMSN +class ViTMSNSelfOutput(nn.Module): + """ + The residual connection is defined in ViTMSNLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
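+ Consequently, this module only applies the dense projection and dropout; the residual addition
+ happens in ViTMSNLayer.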
+ """ + + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->ViTMSN +class ViTMSNAttention(nn.Module): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.attention = ViTMSNSelfAttention(config) + self.output = ViTMSNSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTMSN +class ViTMSNIntermediate(nn.Module): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTOutput with ViT->ViTMSN +class ViTMSNOutput(nn.Module): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->ViTMSN, VIT->VITMSN +class ViTMSNLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + 
self.seq_len_dim = 1 + self.attention = ViTMSNAttention(config) + self.intermediate = ViTMSNIntermediate(config) + self.output = ViTMSNOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in ViTMSN, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViTMSN, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->ViTMSN +class ViTMSNEncoder(nn.Module): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ViTMSNLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ViTMSNPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ViTMSNConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"] + _supports_sdpa = True + _supports_flash_attn_2 = True + + # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211 + # when creating pre-training scripts. 
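+ # note: cls_token, position_embeddings and the optional mask_token are zero-initialized below;
+ # the real values come from the checkpoint when the model is loaded with from_pretrained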
+ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ViTMSNEmbeddings): + module.cls_token.data.zero_() + module.position_embeddings.data.zero_() + if module.mask_token is not None: + module.mask_token.data.zero_() + + +VIT_MSN_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`ViTMSNConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +VIT_MSN_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + interpolate_pos_encoding (`bool`, *optional*): + Whether to interpolate the pre-trained position encodings. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare ViTMSN Model outputting raw hidden-states without any specific head on top.", + VIT_MSN_START_DOCSTRING, +) +class ViTMSNModel(ViTMSNPreTrainedModel): + def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False): + super().__init__(config) + self.config = config + + self.embeddings = ViTMSNEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = ViTMSNEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ViTMSNPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, BaseModelOutput]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, ViTMSNModel + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small") + >>> model = ViTMSNModel.from_pretrained("facebook/vit-msn-small") + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + + if not return_dict: + head_outputs = (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +# Caution: We don't have the weights for the classification head yet. This class +# is here for the users that are interested to fine-tune the base model (ViTMSNModel). +@add_start_docstrings( + """ + ViTMSN Model with an image classification head on top e.g. for ImageNet. 
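+
+    Since the released checkpoints do not include classification-head weights (see the code comment
+    above this class), the head is randomly initialized and meant to be fine-tuned. An illustrative
+    instantiation (the 10-label setup is an assumption, not a released configuration):
+
+    ```python
+    from transformers import ViTMSNForImageClassification
+
+    model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small", num_labels=10)
+    ```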
+ """, + VIT_MSN_START_DOCSTRING, +) +class ViTMSNForImageClassification(ViTMSNPreTrainedModel): + def __init__(self, config: ViTMSNConfig) -> None: + super().__init__(config) + + self.num_labels = config.num_labels + self.vit = ViTMSNModel(config) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, ImageClassifierOutput]: + r""" + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, ViTMSNForImageClassification + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> torch.manual_seed(2) # doctest: +IGNORE_RESULT + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("facebook/vit-msn-small") + >>> model = ViTMSNForImageClassification.from_pretrained("facebook/vit-msn-small") + + >>> inputs = image_processor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... logits = model(**inputs).logits + >>> # model predicts one of the 1000 ImageNet classes + >>> predicted_label = logits.argmax(-1).item() + >>> print(model.config.id2label[predicted_label]) + tusker + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.vit( + pixel_values, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.classifier(sequence_output[:, 0, :]) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["ViTMSNModel", "ViTMSNForImageClassification", "ViTMSNPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/vitdet/__init__.py 
b/docs/transformers/build/lib/transformers/models/vitdet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f96b2fdf7d6247455c115ef40700507076fe6a12 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitdet/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_vitdet import * + from .modeling_vitdet import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/vitdet/configuration_vitdet.py b/docs/transformers/build/lib/transformers/models/vitdet/configuration_vitdet.py new file mode 100644 index 0000000000000000000000000000000000000000..cd91dce9b2961e92112dd2589213c0979a03ad3d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitdet/configuration_vitdet.py @@ -0,0 +1,156 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VitDet model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices + + +logger = logging.get_logger(__name__) + + +class VitDetConfig(BackboneConfigMixin, PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VitDetModel`]. It is used to instantiate an + VitDet model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the VitDet + [google/vitdet-base-patch16-224](https://huggingface.co/google/vitdet-base-patch16-224) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. 
+ num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + mlp_ratio (`int`, *optional*, defaults to 4): + Ratio of mlp hidden dim to embedding dim. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + pretrain_image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image during pretraining. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Stochastic depth rate. + window_block_indices (`List[int]`, *optional*, defaults to `[]`): + List of indices of blocks that should have window attention instead of regular global self-attention. + residual_block_indices (`List[int]`, *optional*, defaults to `[]`): + List of indices of blocks that should have an extra residual block after the MLP. + use_absolute_position_embeddings (`bool`, *optional*, defaults to `True`): + Whether to add absolute position embeddings to the patch embeddings. + use_relative_position_embeddings (`bool`, *optional*, defaults to `False`): + Whether to add relative position embeddings to the attention maps. + window_size (`int`, *optional*, defaults to 0): + The size of the attention window. + out_features (`List[str]`, *optional*): + If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc. + (depending on how many stages the model has). If unset and `out_indices` is set, will default to the + corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. + out_indices (`List[int]`, *optional*): + If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how + many stages the model has). If unset and `out_features` is set, will default to the corresponding stages. + If unset and `out_features` is unset, will default to the last stage. Must be in the + same order as defined in the `stage_names` attribute. 
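+
+    For instance, with the default 12 hidden layers the stages are named `"stem"`, `"stage1"`, ...,
+    `"stage12"`, so passing `out_features=["stage12"]` is equivalent to passing `out_indices=[12]`.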
+ + Example: + + ```python + >>> from transformers import VitDetConfig, VitDetModel + + >>> # Initializing a VitDet configuration + >>> configuration = VitDetConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = VitDetModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vitdet" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + mlp_ratio=4, + hidden_act="gelu", + dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-6, + image_size=224, + pretrain_image_size=224, + patch_size=16, + num_channels=3, + qkv_bias=True, + drop_path_rate=0.0, + window_block_indices=[], + residual_block_indices=[], + use_absolute_position_embeddings=True, + use_relative_position_embeddings=False, + window_size=0, + out_features=None, + out_indices=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.hidden_act = hidden_act + self.dropout_prob = dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.image_size = image_size + self.pretrain_image_size = pretrain_image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + self.drop_path_rate = drop_path_rate + self.window_block_indices = window_block_indices + self.residual_block_indices = residual_block_indices + self.use_absolute_position_embeddings = use_absolute_position_embeddings + self.use_relative_position_embeddings = use_relative_position_embeddings + self.window_size = window_size + + self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, self.num_hidden_layers + 1)] + self._out_features, self._out_indices = get_aligned_output_features_output_indices( + out_features=out_features, out_indices=out_indices, stage_names=self.stage_names + ) + + +__all__ = ["VitDetConfig"] diff --git a/docs/transformers/build/lib/transformers/models/vitdet/modeling_vitdet.py b/docs/transformers/build/lib/transformers/models/vitdet/modeling_vitdet.py new file mode 100644 index 0000000000000000000000000000000000000000..3d740522884d07b03c182944f6a6644f25c225de --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitdet/modeling_vitdet.py @@ -0,0 +1,883 @@ +# coding=utf-8 +# Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch ViTDet backbone.""" + +import collections.abc +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput, BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_vitdet import VitDetConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "VitDetConfig" + + +class VitDetEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) to be consumed by a Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.pretrain_image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + if config.use_absolute_position_embeddings: + # Initialize absolute positional embedding with pretrain image size. + num_positions = num_patches + 1 + self.position_embeddings = nn.Parameter(torch.zeros(1, num_positions, config.hidden_size)) + else: + self.position_embeddings = None + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def get_absolute_positions(self, abs_pos_embeddings, has_cls_token, height, width): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token dimension for the + original embeddings. + + Args: + abs_pos_embeddings (`torch.Tensor`): + Absolute positional embeddings with (1, num_position, num_channels). + has_cls_token (`bool`): + If true, has 1 embedding in abs_pos_embeddings for cls token. + height (`int`): + Height of input image tokens. + width (`int`): + Width of input image tokens. + + Returns: + Absolute positional embeddings after processing with shape (1, height, width, num_channels) + """ + if has_cls_token: + abs_pos_embeddings = abs_pos_embeddings[:, 1:] + num_position = abs_pos_embeddings.shape[1] + size = int(math.sqrt(num_position)) # This is a constant and can be recorded as such in the ONNX export. + if size * size != num_position: + raise ValueError("Absolute position embeddings must be a square number.") + + if torch.jit.is_tracing() or (size != height or size != width): + # nn.functional.interpolate is a noop in case size == height and size == width - we need to always capture this path with jit.trace. 
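+            # Reshape (1, size*size, C) -> (1, size, size, C), then permute to (1, C, size, size)
+            # so the bicubic interpolation below runs over the spatial grid; the result is
+            # permuted back to channels-last before being returned.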
+ new_abs_pos_embeddings = nn.functional.interpolate( + abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2), + size=(height, width), + mode="bicubic", + align_corners=False, + ) + + return new_abs_pos_embeddings.permute(0, 2, 3, 1) + else: + return abs_pos_embeddings.reshape(1, height, width, -1) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + embeddings = self.projection(pixel_values) + + if self.position_embeddings is not None: + # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels) + embeddings = embeddings.permute(0, 2, 3, 1) + # add position embeddings + embeddings = embeddings + self.get_absolute_positions( + self.position_embeddings, True, embeddings.shape[1], embeddings.shape[2] + ) + # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width) + embeddings = embeddings.permute(0, 3, 1, 2) + + return embeddings + + +@torch.jit.script_if_tracing # nn.functional.interpolate's `size` needs to be dynamic. +def get_rel_pos(q_size, k_size, rel_pos): + """ + Get relative positional embeddings according to the relative positions of query and key sizes. + + Args: + q_size (`int`): + Size of query q. + k_size (`int`): + Size of key k. + rel_pos (`torch.Tensor`): + Relative position embeddings (num_embeddings, num_channels). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel position embeddings. + rel_pos_resized = nn.functional.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_relative_positions(attn, queries, rel_pos_h, rel_pos_w, q_size, k_size): + """ + Calculate decomposed Relative Positional Embeddings as introduced in + [MViT2](https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py). + + Args: + attn (`torch.Tensor`): + Attention map. + queries (`torch.Tensor`): + Query q in the attention layer with shape (batch_size, queries_height * queries_width, num_channels). + rel_pos_h (`torch.Tensor`): + Relative position embeddings (Lh, num_channels) for height axis. + rel_pos_w (`torch.Tensor`): + Relative position embeddings (Lw, num_channels) for width axis. + q_size (`Tuple[int]`): + Spatial sequence size of query q with (queries_height, queries_width). + k_size (`Tuple[int]`): + Spatial sequence size of key k with (keys_height, keys_width). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
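+
+    Example (illustrative shapes; the 14x14 query/key grid and 64-dim heads are assumptions):
+
+    ```python
+    >>> attn = torch.zeros(1, 196, 196)
+    >>> queries = torch.randn(1, 196, 64)
+    >>> rel_pos = torch.randn(2 * 14 - 1, 64)
+    >>> add_decomposed_relative_positions(attn, queries, rel_pos, rel_pos, (14, 14), (14, 14)).shape
+    torch.Size([1, 196, 196])
+    ```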
+ """ + queries_height, queries_width = q_size + keys_height, keys_width = k_size + relative_height = get_rel_pos(queries_height, keys_height, rel_pos_h) + relative_width = get_rel_pos(queries_width, keys_width, rel_pos_w) + + batch_size, _, dim = queries.shape + r_q = queries.reshape(batch_size, queries_height, queries_width, dim) + relative_height = torch.einsum("bhwc,hkc->bhwk", r_q, relative_height) + relative_weight = torch.einsum("bhwc,wkc->bhwk", r_q, relative_width) + + attn = ( + attn.view(batch_size, queries_height, queries_width, keys_height, keys_width) + + relative_height[:, :, :, :, None] + + relative_weight[:, :, :, None, :] + ).view(batch_size, queries_height * queries_width, keys_height * keys_width) + + return attn + + +class VitDetAttention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__(self, config, input_size=None): + """ + Args: + config (`VitDetConfig`): + Model configuration. + input_size (`Tuple[int]`, *optional*): + Input resolution, only required in case relative position embeddings are added. + """ + super().__init__() + + dim = config.hidden_size + num_heads = config.num_attention_heads + + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=config.qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_relative_position_embeddings = config.use_relative_position_embeddings + if self.use_relative_position_embeddings: + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, hidden_state, output_attentions=False): + batch_size, height, width, _ = hidden_state.shape + # qkv with shape (3, batch_size, num_heads, height * width, num_channels) + qkv = self.qkv(hidden_state).reshape(batch_size, height * width, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # queries, keys and values have shape (batch_size * num_heads, height * width, num_channels) + queries, keys, values = qkv.reshape(3, batch_size * self.num_heads, height * width, -1).unbind(0) + + attention_scores = (queries * self.scale) @ keys.transpose(-2, -1) + + if self.use_relative_position_embeddings: + attention_scores = add_decomposed_relative_positions( + attention_scores, queries, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width) + ) + + attention_probs = attention_scores.softmax(dim=-1) + + hidden_state = attention_probs @ values + hidden_state = hidden_state.view(batch_size, self.num_heads, height, width, -1) + hidden_state = hidden_state.permute(0, 2, 3, 1, 4) + hidden_state = hidden_state.reshape(batch_size, height, width, -1) + hidden_state = self.proj(hidden_state) + + if output_attentions: + attention_probs = attention_probs.reshape( + batch_size, self.num_heads, attention_probs.shape[-2], attention_probs.shape[-1] + ) + outputs = (hidden_state, attention_probs) + else: + outputs = (hidden_state,) + + return outputs + + +# Copied from transformers.models.beit.modeling_beit.drop_path +def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. + """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +# Copied from transformers.models.beit.modeling_beit.BeitDropPath +class VitDetDropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return "p={}".format(self.drop_prob) + + +class VitDetLayerNorm(nn.Module): + """ + A LayerNorm variant, popularized by Transformers, that performs point-wise mean and variance normalization over the + channel dimension for inputs that have shape (batch_size, channels, height, width). + https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class VitDetResBottleneckBlock(nn.Module): + """ + The standard bottleneck residual block without the last activation layer. It contains 3 conv layers with kernels + 1x1, 3x3, 1x1. + """ + + def __init__(self, config, in_channels, out_channels, bottleneck_channels): + """ + Args: + config (`VitDetConfig`): + Model configuration. + in_channels (`int`): + Number of input channels. + out_channels (`int`): + Number of output channels. + bottleneck_channels (`int`): + Number of output channels for the 3x3 "bottleneck" conv layers. 
+ """ + super().__init__() + self.conv1 = nn.Conv2d(in_channels, bottleneck_channels, 1, bias=False) + self.norm1 = VitDetLayerNorm(bottleneck_channels) + self.act1 = ACT2FN[config.hidden_act] + + self.conv2 = nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1, bias=False) + self.norm2 = VitDetLayerNorm(bottleneck_channels) + self.act2 = ACT2FN[config.hidden_act] + + self.conv3 = nn.Conv2d(bottleneck_channels, out_channels, 1, bias=False) + self.norm3 = VitDetLayerNorm(out_channels) + + def forward(self, x): + out = x + for layer in self.children(): + out = layer(out) + + out = x + out + return out + + +class VitDetMlp(nn.Module): + def __init__(self, config, in_features: int, hidden_features: int) -> None: + super().__init__() + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = ACT2FN[config.hidden_act] + self.fc2 = nn.Linear(hidden_features, in_features) + self.drop = nn.Dropout(config.dropout_prob) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + + return x + + +def window_partition(hidden_state, window_size): + """ + Partition into non-overlapping windows with padding if needed. + + Args: + hidden_state (`torch.Tensor`): + Input tokens with [batch_size, height, width, num_channels]. + window_size (`int`): + Window size. + + Returns: + `tuple(torch.FloatTensor)` comprising various elements: + - windows: windows after partition with [batch_size * num_windows, window_size, window_size, num_channels]. + - (padded_height, padded_width): padded height and width before partition + """ + batch_size, height, width, num_channels = hidden_state.shape + + pad_height = (window_size - height % window_size) % window_size + pad_width = (window_size - width % window_size) % window_size + + # Noop in case pad_width == 0 and pad_height == 0. + hidden_state = nn.functional.pad(hidden_state, (0, 0, 0, pad_width, 0, pad_height)) + + padded_height, padded_width = height + pad_height, width + pad_width + + hidden_state = hidden_state.view( + batch_size, padded_height // window_size, window_size, padded_width // window_size, window_size, num_channels + ) + windows = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels) + return windows, (padded_height, padded_width) + + +def window_unpartition(windows, window_size, pad_height_width, height_width): + """ + Window unpartition into original sequences and removing padding. + + Args: + windows (`torch.Tensor`): + Input tokens with [batch_size * num_windows, window_size, window_size, num_channels]. + window_size (`int`): + Window size. + pad_height_width (`Tuple[int]`): + Padded height and width (padded_height, padded_width). + height_width (`Tuple[int]`): + Original height and width before padding. + + Returns: + hidden_state: unpartitioned sequences with [batch_size, height, width, num_channels]. 
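+
+    Example (illustrative round trip with `window_partition`; the sizes are arbitrary):
+
+    ```python
+    >>> hidden_state = torch.randn(2, 21, 21, 8)
+    >>> windows, pad_height_width = window_partition(hidden_state, window_size=14)
+    >>> windows.shape  # height/width 21 are padded to 28, i.e. 2 x 2 windows per image
+    torch.Size([8, 14, 14, 8])
+    >>> torch.equal(window_unpartition(windows, 14, pad_height_width, (21, 21)), hidden_state)
+    True
+    ```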
+ """ + padded_height, padded_width = pad_height_width + height, width = height_width + batch_size = windows.shape[0] // (padded_height * padded_width // window_size // window_size) + hidden_state = windows.view( + batch_size, padded_height // window_size, padded_width // window_size, window_size, window_size, -1 + ) + hidden_state = hidden_state.permute(0, 1, 3, 2, 4, 5).contiguous() + hidden_state = hidden_state.view(batch_size, padded_height, padded_width, -1) + + # We always have height <= padded_height and width <= padded_width + hidden_state = hidden_state[:, :height, :width, :].contiguous() + return hidden_state + + +class VitDetLayer(nn.Module): + """This corresponds to the Block class in the original implementation.""" + + def __init__( + self, config: VitDetConfig, drop_path_rate: float = 0, window_size: int = 0, use_residual_block: bool = False + ) -> None: + super().__init__() + + dim = config.hidden_size + + image_size = config.image_size + image_size = image_size if isinstance(image_size, (list, tuple)) else (image_size, image_size) + + patch_size = config.patch_size + patch_size = patch_size if isinstance(patch_size, (list, tuple)) else (patch_size, patch_size) + + input_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1]) + self.norm1 = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.attention = VitDetAttention( + config, input_size=input_size if window_size == 0 else (window_size, window_size) + ) + + self.drop_path = VitDetDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity() + self.norm2 = nn.LayerNorm(dim, eps=config.layer_norm_eps) + self.mlp = VitDetMlp(config=config, in_features=dim, hidden_features=int(dim * config.mlp_ratio)) + + self.window_size = window_size + + self.use_residual_block = use_residual_block + if self.use_residual_block: + # Use a residual block with bottleneck channel as dim // 2 + self.residual = VitDetResBottleneckBlock( + config=config, + in_channels=dim, + out_channels=dim, + bottleneck_channels=dim // 2, + ) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + hidden_states = hidden_states.permute(0, 2, 3, 1) + + shortcut = hidden_states + + hidden_states = self.norm1(hidden_states) + + # Window partition + if self.window_size > 0: + height, width = hidden_states.shape[1], hidden_states.shape[2] + hidden_states, pad_height_width = window_partition(hidden_states, self.window_size) + + self_attention_outputs = self.attention( + hidden_states, + output_attentions=output_attentions, + ) + hidden_states = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # Reverse window partition + if self.window_size > 0: + hidden_states = window_unpartition(hidden_states, self.window_size, pad_height_width, (height, width)) + + # first residual connection + hidden_states = shortcut + self.drop_path(hidden_states) + + hidden_states = hidden_states + self.drop_path(self.mlp(self.norm2(hidden_states))) + + hidden_states = hidden_states.permute(0, 3, 1, 2) + + if self.use_residual_block: + hidden_states = self.residual(hidden_states) + + outputs = (hidden_states,) + outputs + + return outputs + + +class VitDetEncoder(nn.Module): + def __init__(self, config: VitDetConfig) -> None: + super().__init__() + self.config = config + depth = config.num_hidden_layers + + # stochastic depth decay rule + drop_path_rate 
= [x.item() for x in torch.linspace(0, config.drop_path_rate, depth, device="cpu")] + + layers = [] + for i in range(depth): + layers.append( + VitDetLayer( + config, + drop_path_rate=drop_path_rate[i], + window_size=config.window_size if i in config.window_block_indices else 0, + use_residual_block=i in config.residual_block_indices, + ) + ) + + self.layer = nn.ModuleList(layers) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +def caffe2_msra_fill(module: nn.Module) -> None: + """ + Initialize `module.weight` using the "MSRAFill" implemented in Caffe2. Also initializes `module.bias` to 0. + + Source: https://detectron2.readthedocs.io/en/latest/_modules/fvcore/nn/weight_init.html. + + Args: + module (torch.nn.Module): module to initialize. + """ + nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + +class VitDetPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = VitDetConfig + base_model_prefix = "vitdet" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = [] + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + elif isinstance(module, VitDetEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + elif isinstance(module, VitDetAttention) and self.config.use_relative_position_embeddings: + module.rel_pos_h.data = nn.init.trunc_normal_( + module.rel_pos_h.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ) + module.rel_pos_w.data = nn.init.trunc_normal_( + module.rel_pos_w.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ) + + elif isinstance(module, VitDetResBottleneckBlock): + for layer in [module.conv1, module.conv2, module.conv3]: + caffe2_msra_fill(layer) + for layer in [module.norm1, module.norm2]: + layer.weight.data.fill_(1.0) + layer.bias.data.zero_() + # zero init last norm layer. + module.norm3.weight.data.zero_() + module.norm3.bias.data.zero_() + + +VITDET_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`VitDetConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +VITDET_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`] + for details. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + "The bare VitDet Transformer model outputting raw hidden-states without any specific head on top.", + VITDET_START_DOCSTRING, +) +class VitDetModel(VitDetPreTrainedModel): + def __init__(self, config: VitDetConfig): + super().__init__(config) + self.config = config + + self.embeddings = VitDetEmbeddings(config) + self.encoder = VitDetEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> VitDetEmbeddings: + return self.embeddings.projection + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(VITDET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + """ + Returns: + + Examples: + + ```python + >>> from transformers import VitDetConfig, VitDetModel + >>> import torch + + >>> config = VitDetConfig() + >>> model = VitDetModel(config) + + >>> pixel_values = torch.randn(1, 3, 224, 224) + + >>> with torch.no_grad(): + ... outputs = model(pixel_values) + + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 768, 14, 14] + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + + if not return_dict: + return (sequence_output,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """ + ViTDet backbone, to be used with frameworks like Mask R-CNN. 
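+
+    An illustrative way to choose which stages the backbone returns (the stage names follow the
+    default 12-layer configuration):
+
+    ```python
+    from transformers import VitDetConfig, VitDetBackbone
+
+    config = VitDetConfig(out_features=["stage6", "stage12"])  # an intermediate and the final stage
+    model = VitDetBackbone(config)
+    ```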
+ """, + VITDET_START_DOCSTRING, +) +class VitDetBackbone(VitDetPreTrainedModel, BackboneMixin): + def __init__(self, config): + super().__init__(config) + super()._init_backbone(config) + + self.embeddings = VitDetEmbeddings(config) + self.encoder = VitDetEncoder(config) + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + + # initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> VitDetEmbeddings: + return self.embeddings.projection + + @add_start_docstrings_to_model_forward(VITDET_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> BackboneOutput: + """ + Returns: + + Examples: + + ```python + >>> from transformers import VitDetConfig, VitDetBackbone + >>> import torch + + >>> config = VitDetConfig() + >>> model = VitDetBackbone(config) + + >>> pixel_values = torch.randn(1, 3, 224, 224) + + >>> with torch.no_grad(): + ... outputs = model(pixel_values) + + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 768, 14, 14] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = ["VitDetModel", "VitDetPreTrainedModel", "VitDetBackbone"] diff --git a/docs/transformers/build/lib/transformers/models/vitmatte/__init__.py b/docs/transformers/build/lib/transformers/models/vitmatte/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b87cea448ab5296902d91d351d06252c52a1386 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitmatte/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
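+# Note: as in the other model subpackages, `_LazyModule` below defers importing the heavy
+# submodules until an attribute is first accessed, so e.g.
+# `from transformers.models.vitmatte import VitMatteImageProcessor` only imports
+# `image_processing_vitmatte` at that point.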
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_vitmatte import * + from .image_processing_vitmatte import * + from .modeling_vitmatte import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/vitmatte/configuration_vitmatte.py b/docs/transformers/build/lib/transformers/models/vitmatte/configuration_vitmatte.py new file mode 100644 index 0000000000000000000000000000000000000000..b9f78043306b72e3951ecb16bb1bfcb868abac20 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitmatte/configuration_vitmatte.py @@ -0,0 +1,136 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VitMatte model configuration""" + +import copy +from typing import List + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto.configuration_auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class VitMatteConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of [`VitMatteForImageMatting`]. It is used to + instantiate a ViTMatte model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the ViTMatte + [hustvl/vitmatte-small-composition-1k](https://huggingface.co/hustvl/vitmatte-small-composition-1k) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitDetConfig()`): + The configuration of the backbone model. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. 
+ hidden_size (`int`, *optional*, defaults to 384): + The number of input channels of the decoder. + batch_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the batch norm layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + convstream_hidden_sizes (`List[int]`, *optional*, defaults to `[48, 96, 192]`): + The output channels of the ConvStream module. + fusion_hidden_sizes (`List[int]`, *optional*, defaults to `[256, 128, 64, 32]`): + The output channels of the Fusion blocks. + + Example: + + ```python + >>> from transformers import VitMatteConfig, VitMatteForImageMatting + + >>> # Initializing a ViTMatte hustvl/vitmatte-small-composition-1k style configuration + >>> configuration = VitMatteConfig() + + >>> # Initializing a model (with random weights) from the hustvl/vitmatte-small-composition-1k style configuration + >>> model = VitMatteForImageMatting(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vitmatte" + + def __init__( + self, + backbone_config: PretrainedConfig = None, + backbone=None, + use_pretrained_backbone=False, + use_timm_backbone=False, + backbone_kwargs=None, + hidden_size: int = 384, + batch_norm_eps: float = 1e-5, + initializer_range: float = 0.02, + convstream_hidden_sizes: List[int] = [48, 96, 192], + fusion_hidden_sizes: List[int] = [256, 128, 64, 32], + **kwargs, + ): + super().__init__(**kwargs) + + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `VitDet` backbone.") + backbone_config = CONFIG_MAPPING["vitdet"](out_features=["stage4"]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + self.batch_norm_eps = batch_norm_eps + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.convstream_hidden_sizes = convstream_hidden_sizes + self.fusion_hidden_sizes = fusion_hidden_sizes + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 
Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output["backbone_config"] = self.backbone_config.to_dict() + output["model_type"] = self.__class__.model_type + return output + + +__all__ = ["VitMatteConfig"] diff --git a/docs/transformers/build/lib/transformers/models/vitmatte/convert_vitmatte_to_hf.py b/docs/transformers/build/lib/transformers/models/vitmatte/convert_vitmatte_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..5153e1faf525ea767e81e120a341acdfc60e5373 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitmatte/convert_vitmatte_to_hf.py @@ -0,0 +1,170 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert VitMatte checkpoints from the original repository. + +URL: https://github.com/hustvl/ViTMatte +""" + +import argparse + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import VitDetConfig, VitMatteConfig, VitMatteForImageMatting, VitMatteImageProcessor + + +def get_config(model_name): + hidden_size = 384 if "small" in model_name else 768 + num_attention_heads = 6 if "small" in model_name else 12 + + backbone_config = VitDetConfig( + num_channels=4, + image_size=512, + pretrain_image_size=224, + patch_size=16, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + use_absolute_position_embeddings=True, + use_relative_position_embeddings=True, + window_size=14, + # 2, 5, 8, 11 for global attention + window_block_indices=[0, 1, 3, 4, 6, 7, 9, 10], + residual_block_indices=[2, 5, 8, 11], + out_features=["stage12"], + ) + + return VitMatteConfig(backbone_config=backbone_config, hidden_size=hidden_size) + + +# here we list all keys to be renamed (original name on the left, our name on the right) +def create_rename_keys(config): + rename_keys = [] + + # fmt: off + # stem + rename_keys.append(("backbone.pos_embed", "backbone.embeddings.position_embeddings")) + rename_keys.append(("backbone.patch_embed.proj.weight", "backbone.embeddings.projection.weight")) + rename_keys.append(("backbone.patch_embed.proj.bias", "backbone.embeddings.projection.bias")) + # fmt: on + + return rename_keys + + +def rename_key(dct, old, new): + val = dct.pop(old) + dct[new] = val + + +def convert_vitmatte_checkpoint(model_name, pytorch_dump_folder_path, push_to_hub): + config = get_config(model_name) + + # load original state dict + model_name_to_filename = { + "vitmatte-small-composition-1k": "ViTMatte_S_Com.pth", + "vitmatte-base-composition-1k": "ViTMatte_B_Com.pth", + "vitmatte-small-distinctions-646": "ViTMatte_S_DIS.pth", + "vitmatte-base-distinctions-646": "ViTMatte_B_DIS.pth", + } + + filename = model_name_to_filename[model_name] + filepath = hf_hub_download(repo_id="nielsr/vitmatte-checkpoints", filename=filename, repo_type="model") + state_dict = torch.load(filepath, map_location="cpu", 
weights_only=True) + + # rename keys + for key in state_dict.copy().keys(): + val = state_dict.pop(key) + if "backbone.blocks" in key: + key = key.replace("backbone.blocks", "backbone.encoder.layer") + if "attn" in key: + key = key.replace("attn", "attention") + if "fusion_blks" in key: + key = key.replace("fusion_blks", "fusion_blocks") + if "bn" in key: + key = key.replace("bn", "batch_norm") + state_dict[key] = val + + # rename keys + rename_keys = create_rename_keys(config) + for src, dest in rename_keys: + rename_key(state_dict, src, dest) + + # create model + processor = VitMatteImageProcessor() + model = VitMatteForImageMatting(config) + model.eval() + + # load state dict + model.load_state_dict(state_dict) + + # verify on dummy image + trimap + url = "https://github.com/hustvl/ViTMatte/blob/main/demo/bulb_rgb.png?raw=true" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + url = "https://github.com/hustvl/ViTMatte/blob/main/demo/bulb_trimap.png?raw=true" + trimap = Image.open(requests.get(url, stream=True).raw) + + pixel_values = processor(images=image, trimaps=trimap.convert("L"), return_tensors="pt").pixel_values + + with torch.no_grad(): + alphas = model(pixel_values).alphas + + if model_name == "vitmatte-small-composition-1k": + expected_slice = torch.tensor([[0.9977, 0.9987, 0.9990], [0.9980, 0.9998, 0.9998], [0.9983, 0.9998, 0.9998]]) + elif model_name == "vitmatte-base-composition-1k": + expected_slice = torch.tensor([[0.9972, 0.9971, 0.9981], [0.9948, 0.9987, 0.9994], [0.9963, 0.9992, 0.9995]]) + elif model_name == "vitmatte-small-distinctions-646": + expected_slice = torch.tensor([[0.9880, 0.9970, 0.9972], [0.9960, 0.9996, 0.9997], [0.9963, 0.9996, 0.9997]]) + elif model_name == "vitmatte-base-distinctions-646": + expected_slice = torch.tensor([[0.9963, 0.9998, 0.9999], [0.9995, 1.0000, 1.0000], [0.9992, 0.9999, 1.0000]]) + + assert torch.allclose(alphas[0, 0, :3, :3], expected_slice, atol=1e-4) + print("Looks ok!") + + if pytorch_dump_folder_path is not None: + print(f"Saving model and processor of {model_name} to {pytorch_dump_folder_path}") + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + print(f"Pushing model and processor for {model_name} to hub") + model.push_to_hub(f"hustvl/{model_name}") + processor.push_to_hub(f"hustvl/{model_name}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="vitmatte-small-composition-1k", + type=str, + choices=[ + "vitmatte-small-composition-1k", + "vitmatte-base-composition-1k", + "vitmatte-small-distinctions-646", + "vitmatte-base-distinctions-646", + ], + help="Name of the VitMatte model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." 
+ ) + + args = parser.parse_args() + convert_vitmatte_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/docs/transformers/build/lib/transformers/models/vitmatte/image_processing_vitmatte.py b/docs/transformers/build/lib/transformers/models/vitmatte/image_processing_vitmatte.py new file mode 100644 index 0000000000000000000000000000000000000000..4c3b06e08815e2b8ebac1864da691900de660156 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitmatte/image_processing_vitmatte.py @@ -0,0 +1,272 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for ViTMatte.""" + +from typing import List, Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import pad, to_channel_dimension_format +from ...image_utils import ( + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ChannelDimension, + ImageInput, + get_image_size, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import TensorType, filter_out_non_signature_kwargs, logging + + +logger = logging.get_logger(__name__) + + +class VitMatteImageProcessor(BaseImageProcessor): + r""" + Constructs a ViTMatte image processor. + + Args: + do_rescale (`bool`, *optional*, defaults to `True`): + Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale` + parameter in the `preprocess` method. + rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): + Scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter in the + `preprocess` method. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` + method. + image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`): + Mean to use if normalizing the image. This is a float or list of floats the length of the number of + channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. + image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`): + Standard deviation to use if normalizing the image. This is a float or list of floats the length of the + number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. + do_pad (`bool`, *optional*, defaults to `True`): + Whether to pad the image to make the width and height divisible by `size_divisibility`. Can be overridden + by the `do_pad` parameter in the `preprocess` method. + size_divisibility (`int`, *optional*, defaults to 32): + The width and height of the image will be padded to be divisible by this number. 
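+
+    Example (an illustrative sketch; the checkpoint name and demo image URLs below are the ones used by the
+    conversion script in this file set):
+
+    ```python
+    >>> from transformers import VitMatteImageProcessor
+    >>> import requests
+    >>> from PIL import Image
+
+    >>> processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k")
+
+    >>> url = "https://github.com/hustvl/ViTMatte/blob/main/demo/bulb_rgb.png?raw=true"
+    >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
+    >>> url = "https://github.com/hustvl/ViTMatte/blob/main/demo/bulb_trimap.png?raw=true"
+    >>> trimap = Image.open(requests.get(url, stream=True).raw).convert("L")
+
+    >>> # the RGB image and the single-channel trimap are concatenated into 4-channel pixel values,
+    >>> # padded so that height and width are divisible by `size_divisibility`
+    >>> inputs = processor(images=image, trimaps=trimap, return_tensors="pt")
+    >>> inputs.pixel_values.shape[1]
+    4
+    ```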
+ """ + + model_input_names = ["pixel_values"] + + def __init__( + self, + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: bool = True, + size_divisibility: int = 32, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.do_rescale = do_rescale + self.do_normalize = do_normalize + self.do_pad = do_pad + self.rescale_factor = rescale_factor + self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN + self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD + self.size_divisibility = size_divisibility + + def pad_image( + self, + image: np.ndarray, + size_divisibility: int = 32, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """ + Args: + image (`np.ndarray`): + Image to pad. + size_divisibility (`int`, *optional*, defaults to 32): + The width and height of the image will be padded to be divisible by this number. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + if input_data_format is None: + input_data_format = infer_channel_dimension_format(image) + + height, width = get_image_size(image, input_data_format) + + pad_height = 0 if height % size_divisibility == 0 else size_divisibility - height % size_divisibility + pad_width = 0 if width % size_divisibility == 0 else size_divisibility - width % size_divisibility + if pad_width + pad_height > 0: + padding = ((0, pad_height), (0, pad_width)) + image = pad(image, padding=padding, data_format=data_format, input_data_format=input_data_format) + + if data_format is not None: + image = to_channel_dimension_format(image, data_format, input_data_format) + + return image + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + trimaps: ImageInput, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + do_pad: Optional[bool] = None, + size_divisibility: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. 
If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + trimaps (`ImageInput`): + Trimap to preprocess. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + do_pad (`bool`, *optional*, defaults to `self.do_pad`): + Whether to pad the image. + size_divisibility (`int`, *optional*, defaults to `self.size_divisibility`): + The size divisibility to pad the image to if `do_pad` is set to `True`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_pad = do_pad if do_pad is not None else self.do_pad + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + size_divisibility = size_divisibility if size_divisibility is not None else self.size_divisibility + + images = make_list_of_images(images) + trimaps = make_list_of_images(trimaps, expected_ndims=2) + + if not valid_images(trimaps): + raise ValueError( + "Invalid trimap type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." 
+ ) + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_pad=do_pad, + size_divisibility=size_divisibility, + ) + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + trimaps = [to_numpy_array(trimap) for trimap in trimaps] + + if do_rescale and is_scaled_image(images[0]): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + if do_rescale: + images = [ + self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + for image in images + ] + trimaps = [ + self.rescale(image=trimap, scale=rescale_factor, input_data_format=input_data_format) + for trimap in trimaps + ] + + if do_normalize: + images = [ + self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) + for image in images + ] + + # concatenate images and trimaps + images = [ + np.concatenate([image, np.expand_dims(trimap, axis=-1)], axis=-1) for image, trimap in zip(images, trimaps) + ] + + if do_pad: + images = [ + self.pad_image(image, size_divisibility=size_divisibility, input_data_format=input_data_format) + for image in images + ] + + images = [ + to_channel_dimension_format(image=image, channel_dim=data_format, input_channel_dim=input_data_format) + for image in images + ] + + data = {"pixel_values": images} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["VitMatteImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/vitmatte/modeling_vitmatte.py b/docs/transformers/build/lib/transformers/models/vitmatte/modeling_vitmatte.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3aa5b883b2483bb9070002fbc7aba3ee1ffe01 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitmatte/modeling_vitmatte.py @@ -0,0 +1,341 @@ +# coding=utf-8 +# Copyright 2023 HUST-VL and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch ViTMatte model.""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn + +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from ...utils.backbone_utils import load_backbone +from .configuration_vitmatte import VitMatteConfig + + +# General docstring +_CONFIG_FOR_DOC = "VitMatteConfig" + + +@dataclass +class ImageMattingOutput(ModelOutput): + """ + Class for outputs of image matting models. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Loss. + alphas (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Estimated alpha values. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + alphas: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +class VitMattePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = VitMatteConfig + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = [] + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + + +class VitMatteBasicConv3x3(nn.Module): + """ + Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers. + """ + + def __init__(self, config, in_channels, out_channels, stride=2, padding=1): + super().__init__() + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=False, + ) + self.batch_norm = nn.BatchNorm2d(out_channels, eps=config.batch_norm_eps) + self.relu = nn.ReLU() + + def forward(self, hidden_state): + hidden_state = self.conv(hidden_state) + hidden_state = self.batch_norm(hidden_state) + hidden_state = self.relu(hidden_state) + + return hidden_state + + +class VitMatteConvStream(nn.Module): + """ + Simple ConvStream containing a series of basic conv3x3 layers to extract detail features. + """ + + def __init__(self, config): + super().__init__() + + # We use a default in-case there isn't a backbone config set. This is for backwards compatibility and + # to enable loading HF backbone models. 
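+        # The fallback of 4 input channels corresponds to a 3-channel RGB image concatenated with a 1-channel
+        # trimap, which is what `VitMatteImageProcessor.preprocess` produces.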
+ in_channels = 4 + if config.backbone_config is not None: + in_channels = config.backbone_config.num_channels + + out_channels = config.convstream_hidden_sizes + + self.convs = nn.ModuleList() + self.conv_chans = [in_channels] + out_channels + + for i in range(len(self.conv_chans) - 1): + in_chan_ = self.conv_chans[i] + out_chan_ = self.conv_chans[i + 1] + self.convs.append(VitMatteBasicConv3x3(config, in_chan_, out_chan_)) + + def forward(self, pixel_values): + out_dict = {"detailed_feature_map_0": pixel_values} + embeddings = pixel_values + for i in range(len(self.convs)): + embeddings = self.convs[i](embeddings) + name_ = "detailed_feature_map_" + str(i + 1) + out_dict[name_] = embeddings + + return out_dict + + +class VitMatteFusionBlock(nn.Module): + """ + Simple fusion block to fuse features from ConvStream and Plain Vision Transformer. + """ + + def __init__(self, config, in_channels, out_channels): + super().__init__() + self.conv = VitMatteBasicConv3x3(config, in_channels, out_channels, stride=1, padding=1) + + def forward(self, features, detailed_feature_map): + upscaled_features = nn.functional.interpolate(features, scale_factor=2, mode="bilinear", align_corners=False) + out = torch.cat([detailed_feature_map, upscaled_features], dim=1) + out = self.conv(out) + + return out + + +class VitMatteHead(nn.Module): + """ + Simple Matting Head, containing only conv3x3 and conv1x1 layers. + """ + + def __init__(self, config): + super().__init__() + + in_channels = config.fusion_hidden_sizes[-1] + mid_channels = 16 + + self.matting_convs = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(mid_channels), + nn.ReLU(True), + nn.Conv2d(mid_channels, 1, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, hidden_state): + hidden_state = self.matting_convs(hidden_state) + + return hidden_state + + +class VitMatteDetailCaptureModule(nn.Module): + """ + Simple and lightweight Detail Capture Module for ViT Matting. + """ + + def __init__(self, config): + super().__init__() + if len(config.fusion_hidden_sizes) != len(config.convstream_hidden_sizes) + 1: + raise ValueError( + "The length of fusion_hidden_sizes should be equal to the length of convstream_hidden_sizes + 1." + ) + + self.config = config + self.convstream = VitMatteConvStream(config) + self.conv_chans = self.convstream.conv_chans + + self.fusion_blocks = nn.ModuleList() + self.fusion_channels = [config.hidden_size] + config.fusion_hidden_sizes + + for i in range(len(self.fusion_channels) - 1): + self.fusion_blocks.append( + VitMatteFusionBlock( + config=config, + in_channels=self.fusion_channels[i] + self.conv_chans[-(i + 1)], + out_channels=self.fusion_channels[i + 1], + ) + ) + + self.matting_head = VitMatteHead(config) + + def forward(self, features, pixel_values): + detail_features = self.convstream(pixel_values) + for i in range(len(self.fusion_blocks)): + detailed_feature_map_name = "detailed_feature_map_" + str(len(self.fusion_blocks) - i - 1) + features = self.fusion_blocks[i](features, detail_features[detailed_feature_map_name]) + + alphas = torch.sigmoid(self.matting_head(features)) + + return alphas + + +VITMATTE_START_DOCSTRING = r""" + Parameters: + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use + it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. 
+    config ([`VitMatteConfig`]): Model configuration class with all the parameters of the model.
+        Initializing with a config file does not load the weights associated with the model, only the
+        configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VITMATTE_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
+            [`AutoImageProcessor`]. See [`VitMatteImageProcessor.__call__`] for details.
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers in case the backbone has them. See
+            `attentions` under returned tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers of the backbone. See `hidden_states` under
+            returned tensors for more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+    """ViTMatte framework leveraging any vision backbone (e.g. ViTDet) for image matting.""",
+    VITMATTE_START_DOCSTRING,
+)
+class VitMatteForImageMatting(VitMattePreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+
+        self.backbone = load_backbone(config)
+        self.decoder = VitMatteDetailCaptureModule(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(VITMATTE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @replace_return_docstrings(output_type=ImageMattingOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        """
+        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+            Ground truth alpha mattes for computing the loss.
+
+        Returns:
+
+        Examples:
+
+        ```python
+        >>> from transformers import VitMatteImageProcessor, VitMatteForImageMatting
+        >>> import torch
+        >>> from PIL import Image
+        >>> from huggingface_hub import hf_hub_download
+
+        >>> processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k")
+        >>> model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k")
+
+        >>> filepath = hf_hub_download(
+        ...     repo_id="hf-internal-testing/image-matting-fixtures", filename="image.png", repo_type="dataset"
+        ... )
+        >>> image = Image.open(filepath).convert("RGB")
+        >>> filepath = hf_hub_download(
+        ...     repo_id="hf-internal-testing/image-matting-fixtures", filename="trimap.png", repo_type="dataset"
+        ... )
+        >>> trimap = Image.open(filepath).convert("L")
+
+        >>> # prepare image + trimap for the model
+        >>> inputs = processor(images=image, trimaps=trimap, return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...
alphas = model(**inputs).alphas + >>> print(alphas.shape) + torch.Size([1, 1, 640, 960]) + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + loss = None + if labels is not None: + raise NotImplementedError("Training is not yet supported") + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, output_hidden_states=output_hidden_states, output_attentions=output_attentions + ) + + features = outputs.feature_maps[-1] + alphas = self.decoder(features, pixel_values) + + if not return_dict: + output = (alphas,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return ImageMattingOutput( + loss=loss, + alphas=alphas, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["VitMattePreTrainedModel", "VitMatteForImageMatting"] diff --git a/docs/transformers/build/lib/transformers/models/vitpose/__init__.py b/docs/transformers/build/lib/transformers/models/vitpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a57524cce2143830e294c1920e395436d670add --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitpose/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_vitpose import * + from .image_processing_vitpose import * + from .modeling_vitpose import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/vitpose/configuration_vitpose.py b/docs/transformers/build/lib/transformers/models/vitpose/configuration_vitpose.py new file mode 100644 index 0000000000000000000000000000000000000000..aba8fec7ae41a1ebd7caf27f582b3b1913cad06c --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitpose/configuration_vitpose.py @@ -0,0 +1,126 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""VitPose model configuration""" + +from typing import Optional + +from ...configuration_utils import PretrainedConfig +from ...utils import logging +from ...utils.backbone_utils import verify_backbone_config_arguments +from ..auto.configuration_auto import CONFIG_MAPPING + + +logger = logging.get_logger(__name__) + + +class VitPoseConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VitPoseForPoseEstimation`]. It is used to instantiate a + VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the VitPose + [usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + backbone_config (`PretrainedConfig` or `dict`, *optional*, defaults to `VitPoseBackboneConfig()`): + The configuration of the backbone model. Currently, only `backbone_config` with `vitpose_backbone` as `model_type` is supported. + backbone (`str`, *optional*): + Name of backbone to use when `backbone_config` is `None`. If `use_pretrained_backbone` is `True`, this + will load the corresponding pretrained weights from the timm or transformers library. If `use_pretrained_backbone` + is `False`, this loads the backbone's config and uses that to initialize the backbone with random weights. + use_pretrained_backbone (`bool`, *optional*, defaults to `False`): + Whether to use pretrained weights for the backbone. + use_timm_backbone (`bool`, *optional*, defaults to `False`): + Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers + library. + backbone_kwargs (`dict`, *optional*): + Keyword arguments to be passed to AutoBackbone when loading from a checkpoint + e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_factor (`int`, *optional*, defaults to 4): + Factor to upscale the feature maps coming from the ViT backbone. + use_simple_decoder (`bool`, *optional*, defaults to `True`): + Whether to use a `VitPoseSimpleDecoder` to decode the feature maps from the backbone into heatmaps. Otherwise it uses `VitPoseClassicDecoder`. + + + Example: + + ```python + >>> from transformers import VitPoseConfig, VitPoseForPoseEstimation + + >>> # Initializing a VitPose configuration + >>> configuration = VitPoseConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = VitPoseForPoseEstimation(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vitpose" + + def __init__( + self, + backbone_config: Optional[PretrainedConfig] = None, + backbone: Optional[str] = None, + use_pretrained_backbone: bool = False, + use_timm_backbone: bool = False, + backbone_kwargs: Optional[dict] = None, + initializer_range: float = 0.02, + scale_factor: int = 4, + use_simple_decoder: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + + if use_pretrained_backbone: + logger.info( + "`use_pretrained_backbone` is `True`. For the pure inference purpose of VitPose weight do not set this value." 
+ ) + if use_timm_backbone: + raise ValueError("use_timm_backbone set `True` is not supported at the moment.") + + if backbone_config is None and backbone is None: + logger.info("`backbone_config` is `None`. Initializing the config with the default `VitPose` backbone.") + backbone_config = CONFIG_MAPPING["vitpose_backbone"](out_indices=[4]) + elif isinstance(backbone_config, dict): + backbone_model_type = backbone_config.get("model_type") + config_class = CONFIG_MAPPING[backbone_model_type] + backbone_config = config_class.from_dict(backbone_config) + + verify_backbone_config_arguments( + use_timm_backbone=use_timm_backbone, + use_pretrained_backbone=use_pretrained_backbone, + backbone=backbone, + backbone_config=backbone_config, + backbone_kwargs=backbone_kwargs, + ) + + self.backbone_config = backbone_config + self.backbone = backbone + self.use_pretrained_backbone = use_pretrained_backbone + self.use_timm_backbone = use_timm_backbone + self.backbone_kwargs = backbone_kwargs + + self.initializer_range = initializer_range + self.scale_factor = scale_factor + self.use_simple_decoder = use_simple_decoder + + +__all__ = ["VitPoseConfig"] diff --git a/docs/transformers/build/lib/transformers/models/vitpose/convert_vitpose_to_hf.py b/docs/transformers/build/lib/transformers/models/vitpose/convert_vitpose_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..e4666751a1078d487756f9681cd9fd21fbc58212 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitpose/convert_vitpose_to_hf.py @@ -0,0 +1,428 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert VitPose checkpoints from the original repository. + +URL: https://github.com/vitae-transformer/vitpose + +Notebook to get the original logits: https://colab.research.google.com/drive/1QDX_2POTpl6JaZAV2WIFjuiqDsDwiqMZ?usp=sharing. 
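+
+Example invocation (using the flags defined by this script's argument parser):
+
+    python convert_vitpose_to_hf.py --model_name vitpose-base-simple --pytorch_dump_folder_path ./vitpose-base-simple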
+""" + +import argparse +import os +import re + +import requests +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from transformers import VitPoseBackboneConfig, VitPoseConfig, VitPoseForPoseEstimation, VitPoseImageProcessor + + +ORIGINAL_TO_CONVERTED_KEY_MAPPING = { + r"patch_embed.proj": "embeddings.patch_embeddings.projection", + r"pos_embed": "embeddings.position_embeddings", + r"blocks": "encoder.layer", + r"attn.proj": "attention.output.dense", + r"attn": "attention.self", + r"norm1": "layernorm_before", + r"norm2": "layernorm_after", + r"last_norm": "layernorm", + r"keypoint_head": "head", + r"final_layer": "conv", +} + +MODEL_TO_FILE_NAME_MAPPING = { + # VitPose models, simple decoder + "vitpose-base-simple": "vitpose-b-simple.pth", + # VitPose models, classic decoder + "vitpose-base": "vitpose-b.pth", + # VitPose models, COCO-AIC-MPII + "vitpose-base-coco-aic-mpii": "vitpose_base_coco_aic_mpii.pth", + # VitPose+ models + "vitpose-plus-small": "vitpose+_small.pth", + "vitpose-plus-base": "vitpose+_base.pth", + "vitpose-plus-large": "vitpose+_large.pth", + "vitpose-plus-huge": "vitpose+_huge.pth", +} + + +def get_config(model_name): + if "plus" in model_name: + num_experts = 6 + if "small" in model_name: + part_features = 96 + out_indices = [12] + elif "base" in model_name: + part_features = 192 + out_indices = [12] + elif "large" in model_name: + part_features = 256 + out_indices = [24] + elif "huge" in model_name: + part_features = 320 + out_indices = [32] + else: + raise ValueError(f"Model {model_name} not supported") + else: + num_experts = 1 + part_features = 0 + + # size of the architecture + if "small" in model_name: + hidden_size = 384 + num_hidden_layers = 12 + num_attention_heads = 12 + elif "large" in model_name: + hidden_size = 1024 + num_hidden_layers = 24 + num_attention_heads = 16 + elif "huge" in model_name: + hidden_size = 1280 + num_hidden_layers = 32 + num_attention_heads = 16 + + backbone_config = VitPoseBackboneConfig( + out_indices=out_indices, + hidden_size=hidden_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_experts=num_experts, + part_features=part_features, + ) + + use_simple_decoder = "simple" in model_name + + edges = [ + [15, 13], + [13, 11], + [16, 14], + [14, 12], + [11, 12], + [5, 11], + [6, 12], + [5, 6], + [5, 7], + [6, 8], + [7, 9], + [8, 10], + [1, 2], + [0, 1], + [0, 2], + [1, 3], + [2, 4], + [3, 5], + [4, 6], + ] + id2label = { + 0: "Nose", + 1: "L_Eye", + 2: "R_Eye", + 3: "L_Ear", + 4: "R_Ear", + 5: "L_Shoulder", + 6: "R_Shoulder", + 7: "L_Elbow", + 8: "R_Elbow", + 9: "L_Wrist", + 10: "R_Wrist", + 11: "L_Hip", + 12: "R_Hip", + 13: "L_Knee", + 14: "R_Knee", + 15: "L_Ankle", + 16: "R_Ankle", + } + + label2id = {v: k for k, v in id2label.items()} + + config = VitPoseConfig( + backbone_config=backbone_config, + num_labels=17, + use_simple_decoder=use_simple_decoder, + edges=edges, + id2label=id2label, + label2id=label2id, + ) + + return config + + +def convert_old_keys_to_new_keys(state_dict_keys: dict = None): + """ + This function should be applied only once, on the concatenated keys to efficiently rename using + the key mappings. 
+ """ + output_dict = {} + if state_dict_keys is not None: + old_text = "\n".join(state_dict_keys) + new_text = old_text + for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items(): + if replacement is None: + new_text = re.sub(pattern, "", new_text) # an empty line + continue + new_text = re.sub(pattern, replacement, new_text) + output_dict = dict(zip(old_text.split("\n"), new_text.split("\n"))) + return output_dict + + +# We will verify our results on a COCO image +def prepare_img(): + url = "http://images.cocodataset.org/val2017/000000000139.jpg" + image = Image.open(requests.get(url, stream=True).raw) + return image + + +@torch.no_grad() +def write_model(model_name, model_path, push_to_hub, check_logits=True): + # ------------------------------------------------------------ + # Vision model params and config + # ------------------------------------------------------------ + + # params from config + config = get_config(model_name) + + # ------------------------------------------------------------ + # Convert weights + # ------------------------------------------------------------ + + # load original state_dict + filename = MODEL_TO_FILE_NAME_MAPPING[model_name] + print(f"Fetching all parameters from the checkpoint at {filename}...") + + checkpoint_path = hf_hub_download( + repo_id="nielsr/vitpose-original-checkpoints", filename=filename, repo_type="model" + ) + + print("Converting model...") + original_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["state_dict"] + all_keys = list(original_state_dict.keys()) + new_keys = convert_old_keys_to_new_keys(all_keys) + + dim = config.backbone_config.hidden_size + + state_dict = {} + for key in all_keys: + new_key = new_keys[key] + value = original_state_dict[key] + + if re.search("associate_heads", new_key) or re.search("backbone.cls_token", new_key): + # This associated_heads is concept of auxiliary head so does not require in inference stage. 
+            # backbone.cls_token is only used by an optional forward path for dynamically changing the input size;
+            # see https://github.com/ViTAE-Transformer/ViTPose/issues/34 for details.
+            pass
+        elif re.search("qkv", new_key):
+            state_dict[new_key.replace("self.qkv", "attention.query")] = value[:dim]
+            state_dict[new_key.replace("self.qkv", "attention.key")] = value[dim : dim * 2]
+            state_dict[new_key.replace("self.qkv", "attention.value")] = value[-dim:]
+        elif re.search("head", new_key) and not config.use_simple_decoder:
+            # Pattern for deconvolution layers
+            deconv_pattern = r"deconv_layers\.(0|3)\.weight"
+            new_key = re.sub(deconv_pattern, lambda m: f"deconv{int(m.group(1)) // 3 + 1}.weight", new_key)
+            # Pattern for batch normalization layers
+            bn_patterns = [
+                (r"deconv_layers\.(\d+)\.weight", r"batchnorm\1.weight"),
+                (r"deconv_layers\.(\d+)\.bias", r"batchnorm\1.bias"),
+                (r"deconv_layers\.(\d+)\.running_mean", r"batchnorm\1.running_mean"),
+                (r"deconv_layers\.(\d+)\.running_var", r"batchnorm\1.running_var"),
+                (r"deconv_layers\.(\d+)\.num_batches_tracked", r"batchnorm\1.num_batches_tracked"),
+            ]
+
+            for pattern, replacement in bn_patterns:
+                if re.search(pattern, new_key):
+                    # Convert the layer number to the correct batch norm index
+                    layer_num = int(re.search(pattern, key).group(1))
+                    bn_num = layer_num // 3 + 1
+                    new_key = re.sub(pattern, replacement.replace(r"\1", str(bn_num)), new_key)
+            state_dict[new_key] = value
+        else:
+            state_dict[new_key] = value
+
+    print("Loading the checkpoint in a VitPose model.")
+    model = VitPoseForPoseEstimation(config)
+    model.eval()
+    model.load_state_dict(state_dict)
+    print("Checkpoint loaded successfully.")
+
+    # create image processor
+    image_processor = VitPoseImageProcessor()
+
+    # verify image processor
+    image = prepare_img()
+    boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]]
+    pixel_values = image_processor(images=image, boxes=boxes, return_tensors="pt").pixel_values
+
+    filepath = hf_hub_download(repo_id="nielsr/test-image", filename="vitpose_batch_data.pt", repo_type="dataset")
+    original_pixel_values = torch.load(filepath, map_location="cpu", weights_only=True)["img"]
+    # we allow for a small difference in the pixel values due to the original repository using cv2
+    assert torch.allclose(pixel_values, original_pixel_values, atol=1e-1)
+
+    dataset_index = torch.tensor([0])
+
+    with torch.no_grad():
+        print("Shape of original_pixel_values: ", original_pixel_values.shape)
+        print("First values of original_pixel_values: ", original_pixel_values[0, 0, :3, :3])
+
+        # first forward pass
+        outputs = model(original_pixel_values, dataset_index=dataset_index)
+        output_heatmap = outputs.heatmaps
+
+        print("Shape of output_heatmap: ", output_heatmap.shape)
+        print("First values: ", output_heatmap[0, 0, :3, :3])
+
+        # second forward pass (flipped)
+        # this is done since the model uses `flip_test=True` in its test config
+        original_pixel_values_flipped = torch.flip(original_pixel_values, [3])
+        outputs_flipped = model(
+            original_pixel_values_flipped,
+            dataset_index=dataset_index,
+            flip_pairs=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]),
+        )
+        output_flipped_heatmap = outputs_flipped.heatmaps
+
+        outputs.heatmaps = (output_heatmap + output_flipped_heatmap) * 0.5
+
+    # Verify pose_results
+    pose_results = image_processor.post_process_pose_estimation(outputs, boxes=boxes)[0]
+
+    if check_logits:
+        # Simple decoder checkpoints
+        if model_name == "vitpose-base-simple":
+            assert torch.allclose(
+
pose_results[1]["keypoints"][0], + torch.tensor([3.98180511e02, 1.81808380e02]), + atol=5e-2, + ) + assert torch.allclose( + pose_results[1]["scores"][0], + torch.tensor([8.66642594e-01]), + atol=5e-2, + ) + # Classic decoder checkpoints + elif model_name == "vitpose-base": + assert torch.allclose( + pose_results[1]["keypoints"][0], + torch.tensor([3.9807913e02, 1.8182812e02]), + atol=5e-2, + ) + assert torch.allclose( + pose_results[1]["scores"][0], + torch.tensor([8.8235235e-01]), + atol=5e-2, + ) + # COCO-AIC-MPII checkpoints + elif model_name == "vitpose-base-coco-aic-mpii": + assert torch.allclose( + pose_results[1]["keypoints"][0], + torch.tensor([3.98305542e02, 1.81741592e02]), + atol=5e-2, + ) + assert torch.allclose( + pose_results[1]["scores"][0], + torch.tensor([8.69966745e-01]), + atol=5e-2, + ) + # VitPose+ models + elif model_name == "vitpose-plus-small": + assert torch.allclose( + pose_results[1]["keypoints"][0], + torch.tensor([398.1597, 181.6902]), + atol=5e-2, + ) + assert torch.allclose( + pose_results[1]["scores"][0], + torch.tensor(0.9051), + atol=5e-2, + ) + elif model_name == "vitpose-plus-base": + assert torch.allclose( + pose_results[1]["keypoints"][0], + torch.tensor([3.98201294e02, 1.81728302e02]), + atol=5e-2, + ) + assert torch.allclose( + pose_results[1]["scores"][0], + torch.tensor([8.75046968e-01]), + atol=5e-2, + ) + elif model_name == "vitpose-plus-large": + assert torch.allclose( + pose_results[1]["keypoints"][0], + torch.tensor([398.1409, 181.7412]), + atol=5e-2, + ) + assert torch.allclose( + pose_results[1]["scores"][0], + torch.tensor(0.8746), + atol=5e-2, + ) + elif model_name == "vitpose-plus-huge": + assert torch.allclose( + pose_results[1]["keypoints"][0], + torch.tensor([398.2079, 181.8026]), + atol=5e-2, + ) + assert torch.allclose( + pose_results[1]["scores"][0], + torch.tensor(0.8693), + atol=5e-2, + ) + else: + raise ValueError("Model not supported") + print("Conversion successfully done.") + + if model_path is not None: + os.makedirs(model_path, exist_ok=True) + model.save_pretrained(model_path) + image_processor.save_pretrained(model_path) + + if push_to_hub: + print(f"Pushing model and image processor for {model_name} to hub") + # we created a community organization on the hub for this model + # maintained by the Transformers team + model.push_to_hub(f"usyd-community/{model_name}") + image_processor.push_to_hub(f"usyd-community/{model_name}") + + +def main(): + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="vitpose-base-simple", + choices=MODEL_TO_FILE_NAME_MAPPING.keys(), + type=str, + help="Name of the VitPose model you'd like to convert.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to store the converted model." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + parser.add_argument( + "--check_logits", action="store_false", help="Whether or not to verify the logits of the converted model." 
+ ) + + args = parser.parse_args() + write_model( + model_path=args.pytorch_dump_folder_path, + model_name=args.model_name, + push_to_hub=args.push_to_hub, + check_logits=args.check_logits, + ) + + +if __name__ == "__main__": + main() diff --git a/docs/transformers/build/lib/transformers/models/vitpose/image_processing_vitpose.py b/docs/transformers/build/lib/transformers/models/vitpose/image_processing_vitpose.py new file mode 100644 index 0000000000000000000000000000000000000000..387b7225473b61c1b505c3af617a4cd419887eb9 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitpose/image_processing_vitpose.py @@ -0,0 +1,684 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for VitPose.""" + +import itertools +import math +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature +from ...image_transforms import to_channel_dimension_format +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + infer_channel_dimension_format, + is_scaled_image, + make_list_of_images, + to_numpy_array, + valid_images, +) +from ...utils import TensorType, is_scipy_available, is_torch_available, is_vision_available, logging + + +if is_torch_available(): + import torch + +if is_vision_available(): + import PIL + +if is_scipy_available(): + from scipy.linalg import inv + from scipy.ndimage import affine_transform, gaussian_filter + +if TYPE_CHECKING: + from .modeling_vitpose import VitPoseEstimatorOutput + +logger = logging.get_logger(__name__) + + +# inspired by https://github.com/ViTAE-Transformer/ViTPose/blob/d5216452796c90c6bc29f5c5ec0bdba94366768a/mmpose/datasets/datasets/base/kpt_2d_sview_rgb_img_top_down_dataset.py#L132 +def box_to_center_and_scale( + box: Union[Tuple, List, np.ndarray], + image_width: int, + image_height: int, + normalize_factor: float = 200.0, + padding_factor: float = 1.25, +): + """ + Encodes a bounding box in COCO format into (center, scale). + + Args: + box (`Tuple`, `List`, or `np.ndarray`): + Bounding box in COCO format (top_left_x, top_left_y, width, height). + image_width (`int`): + Image width. + image_height (`int`): + Image height. + normalize_factor (`float`): + Width and height scale factor. + padding_factor (`float`): + Bounding box padding factor. + + Returns: + tuple: A tuple containing center and scale. + + - `np.ndarray` [float32](2,): Center of the bbox (x, y). + - `np.ndarray` [float32](2,): Scale of the bbox width & height. 
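+
+    Example (illustrative; the box below already matches the image aspect ratio, so only the `padding_factor`
+    rescales it):
+
+    ```python
+    >>> import numpy as np
+    >>> center, scale = box_to_center_and_scale((0, 0, 100, 200), image_width=100, image_height=200)
+    >>> center
+    array([ 50., 100.], dtype=float32)
+    >>> scale
+    array([0.625, 1.25 ], dtype=float32)
+    ```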
+ """ + + top_left_x, top_left_y, width, height = box[:4] + aspect_ratio = image_width / image_height + center = np.array([top_left_x + width * 0.5, top_left_y + height * 0.5], dtype=np.float32) + + if width > aspect_ratio * height: + height = width * 1.0 / aspect_ratio + elif width < aspect_ratio * height: + width = height * aspect_ratio + + scale = np.array([width / normalize_factor, height / normalize_factor], dtype=np.float32) + scale = scale * padding_factor + + return center, scale + + +def coco_to_pascal_voc(bboxes: np.ndarray) -> np.ndarray: + """ + Converts bounding boxes from the COCO format to the Pascal VOC format. + + In other words, converts from (top_left_x, top_left_y, width, height) format + to (top_left_x, top_left_y, bottom_right_x, bottom_right_y). + + Args: + bboxes (`np.ndarray` of shape `(batch_size, 4)): + Bounding boxes in COCO format. + + Returns: + `np.ndarray` of shape `(batch_size, 4) in Pascal VOC format. + """ + bboxes[:, 2] = bboxes[:, 2] + bboxes[:, 0] - 1 + bboxes[:, 3] = bboxes[:, 3] + bboxes[:, 1] - 1 + + return bboxes + + +def get_keypoint_predictions(heatmaps: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get keypoint predictions from score maps. + + Args: + heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`): + Model predicted heatmaps. + + Returns: + tuple: A tuple containing aggregated results. + + - coords (`np.ndarray` of shape `(batch_size, num_keypoints, 2)`): + Predicted keypoint location. + - scores (`np.ndarray` of shape `(batch_size, num_keypoints, 1)`): + Scores (confidence) of the keypoints. + """ + if not isinstance(heatmaps, np.ndarray): + raise ValueError("Heatmaps should be np.ndarray") + if heatmaps.ndim != 4: + raise ValueError("Heatmaps should be 4-dimensional") + + batch_size, num_keypoints, _, width = heatmaps.shape + heatmaps_reshaped = heatmaps.reshape((batch_size, num_keypoints, -1)) + idx = np.argmax(heatmaps_reshaped, 2).reshape((batch_size, num_keypoints, 1)) + scores = np.amax(heatmaps_reshaped, 2).reshape((batch_size, num_keypoints, 1)) + + preds = np.tile(idx, (1, 1, 2)).astype(np.float32) + preds[:, :, 0] = preds[:, :, 0] % width + preds[:, :, 1] = preds[:, :, 1] // width + + preds = np.where(np.tile(scores, (1, 1, 2)) > 0.0, preds, -1) + return preds, scores + + +def post_dark_unbiased_data_processing(coords: np.ndarray, batch_heatmaps: np.ndarray, kernel: int = 3) -> np.ndarray: + """DARK post-pocessing. Implemented by unbiased_data_processing. + + Paper references: + - Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020). + - Zhang et al. Distribution-Aware Coordinate Representation for Human Pose Estimation (CVPR 2020). + + Args: + coords (`np.ndarray` of shape `(num_persons, num_keypoints, 2)`): + Initial coordinates of human pose. + batch_heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width)`): + Batched heatmaps as predicted by the model. + A batch_size of 1 is used for the bottom up paradigm where all persons share the same heatmap. + A batch_size of `num_persons` is used for the top down paradigm where each person has its own heatmaps. + kernel (`int`, *optional*, defaults to 3): + Gaussian kernel size (K) for modulation. + + Returns: + `np.ndarray` of shape `(num_persons, num_keypoints, 2)` ): + Refined coordinates. 
+ """ + batch_size, num_keypoints, height, width = batch_heatmaps.shape + num_coords = coords.shape[0] + if not (batch_size == 1 or batch_size == num_coords): + raise ValueError("The batch size of heatmaps should be 1 or equal to the batch size of coordinates.") + radius = int((kernel - 1) // 2) + batch_heatmaps = np.array( + [ + [gaussian_filter(heatmap, sigma=0.8, radius=(radius, radius), axes=(0, 1)) for heatmap in heatmaps] + for heatmaps in batch_heatmaps + ] + ) + batch_heatmaps = np.clip(batch_heatmaps, 0.001, 50) + batch_heatmaps = np.log(batch_heatmaps) + + batch_heatmaps_pad = np.pad(batch_heatmaps, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="edge").flatten() + + # calculate indices for coordinates + index = coords[..., 0] + 1 + (coords[..., 1] + 1) * (width + 2) + index += (width + 2) * (height + 2) * np.arange(0, batch_size * num_keypoints).reshape(-1, num_keypoints) + index = index.astype(int).reshape(-1, 1) + i_ = batch_heatmaps_pad[index] + ix1 = batch_heatmaps_pad[index + 1] + iy1 = batch_heatmaps_pad[index + width + 2] + ix1y1 = batch_heatmaps_pad[index + width + 3] + ix1_y1_ = batch_heatmaps_pad[index - width - 3] + ix1_ = batch_heatmaps_pad[index - 1] + iy1_ = batch_heatmaps_pad[index - 2 - width] + + # calculate refined coordinates using Newton's method + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1) + derivative = derivative.reshape(num_coords, num_keypoints, 2, 1) + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1) + hessian = hessian.reshape(num_coords, num_keypoints, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + coords -= np.einsum("ijmn,ijnk->ijmk", hessian, derivative).squeeze() + return coords + + +def transform_preds(coords: np.ndarray, center: np.ndarray, scale: np.ndarray, output_size: np.ndarray) -> np.ndarray: + """Get final keypoint predictions from heatmaps and apply scaling and + translation to map them back to the image. + + Note: + num_keypoints: K + + Args: + coords (`np.ndarray` of shape `(num_keypoints, ndims)`): + + * If ndims=2, corrds are predicted keypoint location. + * If ndims=4, corrds are composed of (x, y, scores, tags) + * If ndims=5, corrds are composed of (x, y, scores, tags, + flipped_tags) + + center (`np.ndarray` of shape `(2,)`): + Center of the bounding box (x, y). + scale (`np.ndarray` of shape `(2,)`): + Scale of the bounding box wrt original image of width and height. + output_size (`np.ndarray` of shape `(2,)`): + Size of the destination heatmaps in (height, width) format. + + Returns: + np.ndarray: Predicted coordinates in the images. + """ + if coords.shape[1] not in (2, 4, 5): + raise ValueError("Coordinates need to have either 2, 4 or 5 dimensions.") + if len(center) != 2: + raise ValueError("Center needs to have 2 elements, one for x and one for y.") + if len(scale) != 2: + raise ValueError("Scale needs to consist of a width and height") + if len(output_size) != 2: + raise ValueError("Output size needs to consist of a height and width") + + # Recover the scale which is normalized by a factor of 200. 
+    scale = scale * 200.0
+
+    # We use unbiased data processing
+    scale_y = scale[1] / (output_size[0] - 1.0)
+    scale_x = scale[0] / (output_size[1] - 1.0)
+
+    target_coords = np.ones_like(coords)
+    target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
+    target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5
+
+    return target_coords
+
+
+def get_warp_matrix(theta: float, size_input: np.ndarray, size_dst: np.ndarray, size_target: np.ndarray):
+    """
+    Calculate the transformation matrix under the constraint of unbiased data processing. Paper ref: Huang et al. The
+    Devil is in the Details: Delving into Unbiased Data Processing for Human Pose Estimation (CVPR 2020).
+
+    Source: https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
+
+    Args:
+        theta (`float`):
+            Rotation angle in degrees.
+        size_input (`np.ndarray`):
+            Size of input image [width, height].
+        size_dst (`np.ndarray`):
+            Size of output image [width, height].
+        size_target (`np.ndarray`):
+            Size of ROI in input plane [w, h].
+
+    Returns:
+        `np.ndarray`: A matrix for transformation.
+    """
+    theta = np.deg2rad(theta)
+    matrix = np.zeros((2, 3), dtype=np.float32)
+    scale_x = size_dst[0] / size_target[0]
+    scale_y = size_dst[1] / size_target[1]
+    matrix[0, 0] = math.cos(theta) * scale_x
+    matrix[0, 1] = -math.sin(theta) * scale_x
+    matrix[0, 2] = scale_x * (
+        -0.5 * size_input[0] * math.cos(theta) + 0.5 * size_input[1] * math.sin(theta) + 0.5 * size_target[0]
+    )
+    matrix[1, 0] = math.sin(theta) * scale_y
+    matrix[1, 1] = math.cos(theta) * scale_y
+    matrix[1, 2] = scale_y * (
+        -0.5 * size_input[0] * math.sin(theta) - 0.5 * size_input[1] * math.cos(theta) + 0.5 * size_target[1]
+    )
+    return matrix
+
+
+def scipy_warp_affine(src, M, size):
+    """
+    This function implements the cv2.warpAffine function using scipy's affine_transform. See
+    https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html and
+    https://docs.opencv.org/4.x/d4/d61/tutorial_warp_affine.html for more details.
+
+    Note: the original implementation of cv2.warpAffine uses cv2.INTER_LINEAR.
+    """
+    channels = [src[..., i] for i in range(src.shape[-1])]
+
+    # Convert to a 3x3 matrix used by SciPy
+    M_scipy = np.vstack([M, [0, 0, 1]])
+    # If you have a matrix for the 'push' transformation, use its inverse (numpy.linalg.inv) in this function.
+    M_inv = inv(M_scipy)
+    M_inv[0, 0], M_inv[0, 1], M_inv[1, 0], M_inv[1, 1], M_inv[0, 2], M_inv[1, 2] = (
+        M_inv[1, 1],
+        M_inv[1, 0],
+        M_inv[0, 1],
+        M_inv[0, 0],
+        M_inv[1, 2],
+        M_inv[0, 2],
+    )
+
+    new_src = [affine_transform(channel, M_inv, output_shape=size, order=1) for channel in channels]
+    new_src = np.stack(new_src, axis=-1)
+    return new_src
+
+
+class VitPoseImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a VitPose image processor.
+
+    Args:
+        do_affine_transform (`bool`, *optional*, defaults to `True`):
+            Whether to apply an affine transformation to the input images.
+        size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 192}`):
+            Resolution of the image after `affine_transform` is applied. Only has an effect if `do_affine_transform`
+            is set to `True`. Can be overridden by `size` in the `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.).
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
+            Scale factor to use if rescaling the image.
+            Can be overridden by `rescale_factor` in the `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether or not to normalize the input with mean and standard deviation.
+        image_mean (`List[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
+            The sequence of means for each channel, to be used when normalizing images.
+        image_std (`List[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
+            The sequence of standard deviations for each channel, to be used when normalizing images.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_affine_transform: bool = True,
+        size: Dict[str, int] = None,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 255,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.do_affine_transform = do_affine_transform
+        self.size = size if size is not None else {"height": 256, "width": 192}
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+        self.normalize_factor = 200.0
+
+    def affine_transform(
+        self,
+        image: np.array,
+        center: Tuple[float],
+        scale: Tuple[float],
+        rotation: float,
+        size: Dict[str, int],
+        data_format: Optional[ChannelDimension] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> np.array:
+        """
+        Apply an affine transformation to an image.
+
+        Args:
+            image (`np.array`):
+                Image to transform.
+            center (`Tuple[float]`):
+                Center of the bounding box (x, y).
+            scale (`Tuple[float]`):
+                Scale of the bounding box with respect to height/width.
+            rotation (`float`):
+                Rotation angle in degrees.
+            size (`Dict[str, int]`):
+                Size of the destination image.
+            data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
+                The channel dimension format of the output image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the input image.
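+
+        Example (a hedged sketch, not from the original docstring; the image and box values are arbitrary):
+
+        ```python
+        >>> import numpy as np
+        >>> processor = VitPoseImageProcessor()
+        >>> image = np.zeros((480, 640, 3), dtype=np.float32)  # channels-last dummy image
+        >>> center = np.array([320.0, 240.0])  # box center (x, y)
+        >>> scale = np.array([2.0, 2.4])  # box size, in units of 200 pixels
+        >>> out = processor.affine_transform(
+        ...     image, center, scale, rotation=0, size={"height": 256, "width": 192}, input_data_format="channels_last"
+        ... )
+        >>> out.shape
+        (256, 192, 3)
+        ```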
+ """ + + data_format = input_data_format if data_format is None else data_format + + size = (size["width"], size["height"]) + + # one uses a pixel standard deviation of 200 pixels + transformation = get_warp_matrix(rotation, center * 2.0, np.array(size) - 1.0, scale * 200.0) + + # input image requires channels last format + image = ( + image + if input_data_format == ChannelDimension.LAST + else to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format) + ) + image = scipy_warp_affine(src=image, M=transformation, size=(size[1], size[0])) + + image = to_channel_dimension_format(image, data_format, ChannelDimension.LAST) + + return image + + def preprocess( + self, + images: ImageInput, + boxes: Union[List[List[float]], np.ndarray], + do_affine_transform: Optional[bool] = None, + size: Dict[str, int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> PIL.Image.Image: + """ + Preprocess an image or batch of images. + + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + + boxes (`List[List[List[float]]]` or `np.ndarray`): + List or array of bounding boxes for each image. Each box should be a list of 4 floats representing the bounding + box coordinates in COCO format (top_left_x, top_left_y, width, height). + + do_affine_transform (`bool`, *optional*, defaults to `self.do_affine_transform`): + Whether to apply an affine transformation to the input images. + size (`Dict[str, int]` *optional*, defaults to `self.size`): + Dictionary in the format `{"height": h, "width": w}` specifying the size of the output image after + resizing. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the image values between [0 - 1]. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Rescale factor to rescale the image by if `do_rescale` is set to `True`. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the image. + image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): + Image mean to use if `do_normalize` is set to `True`. + image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): + Image standard deviation to use if `do_normalize` is set to `True`. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, + width). 
+ """ + do_affine_transform = do_affine_transform if do_affine_transform is not None else self.do_affine_transform + size = size if size is not None else self.size + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + if isinstance(boxes, list) and len(images) != len(boxes): + raise ValueError(f"Batch of images and boxes mismatch : {len(images)} != {len(boxes)}") + elif isinstance(boxes, np.ndarray) and len(images) != boxes.shape[0]: + raise ValueError(f"Batch of images and boxes mismatch : {len(images)} != {boxes.shape[0]}") + + # All transformations expect numpy arrays. + images = [to_numpy_array(image) for image in images] + + if is_scaled_image(images[0]) and do_rescale: + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." + ) + + if input_data_format is None: + # We assume that all images have the same channel dimension format. + input_data_format = infer_channel_dimension_format(images[0]) + + # transformations (affine transformation + rescaling + normalization) + if self.do_affine_transform: + new_images = [] + for image, image_boxes in zip(images, boxes): + for box in image_boxes: + center, scale = box_to_center_and_scale( + box, + image_width=size["width"], + image_height=size["height"], + normalize_factor=self.normalize_factor, + ) + transformed_image = self.affine_transform( + image, center, scale, rotation=0, size=size, input_data_format=input_data_format + ) + new_images.append(transformed_image) + images = new_images + + # For batch processing, the number of boxes must be consistent across all images in the batch. + # When using a list input, the number of boxes can vary dynamically per image. + # The image processor creates pixel_values of shape (batch_size*num_persons, num_channels, height, width) + + all_images = [] + for image in images: + if do_rescale: + image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) + + if do_normalize: + image = self.normalize( + image=image, mean=image_mean, std=image_std, input_data_format=input_data_format + ) + + all_images.append(image) + images = [ + to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) + for image in all_images + ] + + data = {"pixel_values": images} + encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) + + return encoded_inputs + + def keypoints_from_heatmaps( + self, + heatmaps: np.ndarray, + center: np.ndarray, + scale: np.ndarray, + kernel: int = 11, + ): + """ + Get final keypoint predictions from heatmaps and transform them back to + the image. + + Args: + heatmaps (`np.ndarray` of shape `(batch_size, num_keypoints, height, width])`): + Model predicted heatmaps. + center (`np.ndarray` of shape `(batch_size, 2)`): + Center of the bounding box (x, y). 
+            scale (`np.ndarray` of shape `(batch_size, 2)`):
+                Scale of the bounding box with respect to the original image width and height.
+            kernel (`int`, *optional*, defaults to 11):
+                Gaussian kernel size (K) for modulation, which should match the heatmap gaussian sigma when training.
+                K=17 for sigma=3 and K=11 for sigma=2.
+
+        Returns:
+            tuple: A tuple containing keypoint predictions and scores.
+
+            - preds (`np.ndarray` of shape `(batch_size, num_keypoints, 2)`):
+                Predicted keypoint location in images.
+            - scores (`np.ndarray` of shape `(batch_size, num_keypoints, 1)`):
+                Scores (confidence) of the keypoints.
+        """
+        batch_size, _, height, width = heatmaps.shape
+
+        coords, scores = get_keypoint_predictions(heatmaps)
+
+        preds = post_dark_unbiased_data_processing(coords, heatmaps, kernel=kernel)
+
+        # Transform back to the image
+        for i in range(batch_size):
+            preds[i] = transform_preds(preds[i], center=center[i], scale=scale[i], output_size=[height, width])
+
+        return preds, scores
+
+    def post_process_pose_estimation(
+        self,
+        outputs: "VitPoseEstimatorOutput",
+        boxes: Union[List[List[List[float]]], np.ndarray],
+        kernel_size: int = 11,
+        threshold: Optional[float] = None,
+        target_sizes: Union[TensorType, List[Tuple]] = None,
+    ):
+        """
+        Transform the heatmaps into keypoint predictions and transform them back to the image.
+
+        Args:
+            outputs (`VitPoseEstimatorOutput`):
+                VitPoseForPoseEstimation model outputs.
+            boxes (`List[List[List[float]]]` or `np.ndarray`):
+                List or array of bounding boxes for each image. Each box should be a list of 4 floats representing
+                the bounding box coordinates in COCO format (top_left_x, top_left_y, width, height).
+            kernel_size (`int`, *optional*, defaults to 11):
+                Gaussian kernel size (K) for modulation.
+            threshold (`float`, *optional*, defaults to `None`):
+                Score threshold to keep object detection predictions.
+            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
+                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
+                `(height, width)` of each image in the batch. If unset, predictions will be resized with the default
+                value.
+
+        Returns:
+            `List[List[Dict]]`: A list of dictionaries, each dictionary containing the keypoints and boxes for an
+            image in the batch as predicted by the model.
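+
+        Example (a short sketch; `processor`, `outputs` and `boxes` are assumed to come from the
+        `VitPoseForPoseEstimation` usage shown in its forward docstring):
+
+        ```python
+        >>> results = processor.post_process_pose_estimation(outputs, boxes=boxes)
+        >>> keypoints = results[0][0]["keypoints"]  # first image, first detected person
+        >>> scores = results[0][0]["scores"]
+        ```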
+ """ + + # First compute centers and scales for each bounding box + batch_size, num_keypoints, _, _ = outputs.heatmaps.shape + + if target_sizes is not None: + if batch_size != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + + centers = np.zeros((batch_size, 2), dtype=np.float32) + scales = np.zeros((batch_size, 2), dtype=np.float32) + flattened_boxes = list(itertools.chain(*boxes)) + for i in range(batch_size): + if target_sizes is not None: + image_width, image_height = target_sizes[i][0], target_sizes[i][1] + scale_factor = np.array([image_width, image_height, image_width, image_height]) + flattened_boxes[i] = flattened_boxes[i] * scale_factor + width, height = self.size["width"], self.size["height"] + center, scale = box_to_center_and_scale(flattened_boxes[i], image_width=width, image_height=height) + centers[i, :] = center + scales[i, :] = scale + + preds, scores = self.keypoints_from_heatmaps( + outputs.heatmaps.cpu().numpy(), centers, scales, kernel=kernel_size + ) + + all_boxes = np.zeros((batch_size, 4), dtype=np.float32) + all_boxes[:, 0:2] = centers[:, 0:2] + all_boxes[:, 2:4] = scales[:, 0:2] + + poses = torch.tensor(preds) + scores = torch.tensor(scores) + labels = torch.arange(0, num_keypoints) + bboxes_xyxy = torch.tensor(coco_to_pascal_voc(all_boxes)) + + results: List[List[Dict[str, torch.Tensor]]] = [] + + pose_bbox_pairs = zip(poses, scores, bboxes_xyxy) + + for image_bboxes in boxes: + image_results: List[Dict[str, torch.Tensor]] = [] + for _ in image_bboxes: + # Unpack the next pose and bbox_xyxy from the iterator + pose, score, bbox_xyxy = next(pose_bbox_pairs) + score = score.squeeze() + keypoints_labels = labels + if threshold is not None: + keep = score > threshold + pose = pose[keep] + score = score[keep] + keypoints_labels = keypoints_labels[keep] + pose_result = {"keypoints": pose, "scores": score, "labels": keypoints_labels, "bbox": bbox_xyxy} + image_results.append(pose_result) + results.append(image_results) + + return results + + +__all__ = ["VitPoseImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/vitpose/modeling_vitpose.py b/docs/transformers/build/lib/transformers/models/vitpose/modeling_vitpose.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe9738abf572d8d6224fb2232059794e4f4a6a1 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitpose/modeling_vitpose.py @@ -0,0 +1,340 @@ +# coding=utf-8 +# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch VitPose model.""" + +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.backbone_utils import load_backbone +from .configuration_vitpose import VitPoseConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "VitPoseConfig" + + +@dataclass +class VitPoseEstimatorOutput(ModelOutput): + """ + Class for outputs of pose estimation models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Loss is not supported at this moment. See https://github.com/ViTAE-Transformer/ViTPose/tree/main/mmpose/models/losses for further detail. + heatmaps (`torch.FloatTensor` of shape `(batch_size, num_keypoints, height, width)`): + Heatmaps as predicted by the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states + (also called feature maps) of the model at the output of each stage. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + heatmaps: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +class VitPosePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = VitPoseConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +VITPOSE_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`VitPoseConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. 
Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VITPOSE_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`VitPoseImageProcessor`]. See
+            [`VitPoseImageProcessor.__call__`] for details.
+
+        dataset_index (`torch.Tensor` of shape `(batch_size,)`):
+            Index to use in the Mixture-of-Experts (MoE) blocks of the backbone.
+
+            This corresponds to the dataset index used during training. For a single dataset, index 0 refers to that
+            dataset; when training on multiple datasets, index 0 refers to dataset A (e.g. MPII) and index 1 refers
+            to dataset B (e.g. CrowdPose).
+
+        flip_pairs (`torch.Tensor`, *optional*):
+            Pairs of keypoints which are mirrored (for example, left ear -- right ear).
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+def flip_back(output_flipped, flip_pairs, target_type="gaussian-heatmap"):
+    """Flip the flipped heatmaps back to the original form.
+
+    Args:
+        output_flipped (`torch.Tensor` of shape `(batch_size, num_keypoints, height, width)`):
+            The output heatmaps obtained from the flipped images.
+        flip_pairs (`torch.Tensor` of shape `(num_keypoints, 2)`):
+            Pairs of keypoints which are mirrored (for example, left ear -- right ear).
+        target_type (`str`, *optional*, defaults to `"gaussian-heatmap"`):
+            Target type to use. Can be gaussian-heatmap or combined-target.
+            gaussian-heatmap: Classification target with gaussian distribution.
+            combined-target: The combination of classification target (response map) and regression target (offset
+            map). Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased Data Processing for
+            Human Pose Estimation (CVPR 2020).
+
+    Returns:
+        torch.Tensor: heatmaps flipped back to the original form.
+    """
+    if target_type not in ["gaussian-heatmap", "combined-target"]:
+        raise ValueError("target_type should be gaussian-heatmap or combined-target")
+
+    if output_flipped.ndim != 4:
+        raise ValueError("output_flipped should be [batch_size, num_keypoints, height, width]")
+    batch_size, num_keypoints, height, width = output_flipped.shape
+    channels = 1
+    if target_type == "combined-target":
+        channels = 3
+        output_flipped[:, 1::3, ...] = -output_flipped[:, 1::3, ...]
+    output_flipped = output_flipped.reshape(batch_size, -1, channels, height, width)
+    output_flipped_back = output_flipped.clone()
+
+    # Swap left-right parts
+    for left, right in flip_pairs.tolist():
+        output_flipped_back[:, left, ...] = output_flipped[:, right, ...]
+        output_flipped_back[:, right, ...] = output_flipped[:, left, ...]
+    output_flipped_back = output_flipped_back.reshape((batch_size, num_keypoints, height, width))
+    # Flip horizontally
+    output_flipped_back = output_flipped_back.flip(-1)
+    return output_flipped_back
+
+
+class VitPoseSimpleDecoder(nn.Module):
+    """
+    Simple decoding head consisting of a ReLU activation, 4x upsampling and a 3x3 convolution, turning the
+    feature maps into heatmaps.
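+
+    For example (illustrative shapes, assuming the default 256x192 input, a ViT-Base backbone with hidden size 768
+    and patch size 16, `scale_factor=4` and 17 keypoints): a `(batch_size, 768, 16, 12)` feature map is upsampled to
+    `(batch_size, 768, 64, 48)` and projected to `(batch_size, 17, 64, 48)` heatmaps.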
+ """ + + def __init__(self, config) -> None: + super().__init__() + + self.activation = nn.ReLU() + self.upsampling = nn.Upsample(scale_factor=config.scale_factor, mode="bilinear", align_corners=False) + self.conv = nn.Conv2d( + config.backbone_config.hidden_size, config.num_labels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None) -> torch.Tensor: + # Transform input: ReLU + upsample + hidden_state = self.activation(hidden_state) + hidden_state = self.upsampling(hidden_state) + heatmaps = self.conv(hidden_state) + + if flip_pairs is not None: + heatmaps = flip_back(heatmaps, flip_pairs) + + return heatmaps + + +class VitPoseClassicDecoder(nn.Module): + """ + Classic decoding head consisting of a 2 deconvolutional blocks, followed by a 1x1 convolution layer, + turning the feature maps into heatmaps. + """ + + def __init__(self, config: VitPoseConfig): + super().__init__() + + self.deconv1 = nn.ConvTranspose2d( + config.backbone_config.hidden_size, 256, kernel_size=4, stride=2, padding=1, bias=False + ) + self.batchnorm1 = nn.BatchNorm2d(256) + self.relu1 = nn.ReLU() + + self.deconv2 = nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1, bias=False) + self.batchnorm2 = nn.BatchNorm2d(256) + self.relu2 = nn.ReLU() + + self.conv = nn.Conv2d(256, config.num_labels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_state: torch.Tensor, flip_pairs: Optional[torch.Tensor] = None): + hidden_state = self.deconv1(hidden_state) + hidden_state = self.batchnorm1(hidden_state) + hidden_state = self.relu1(hidden_state) + + hidden_state = self.deconv2(hidden_state) + hidden_state = self.batchnorm2(hidden_state) + hidden_state = self.relu2(hidden_state) + + heatmaps = self.conv(hidden_state) + + if flip_pairs is not None: + heatmaps = flip_back(heatmaps, flip_pairs) + + return heatmaps + + +@add_start_docstrings( + "The VitPose model with a pose estimation head on top.", + VITPOSE_START_DOCSTRING, +) +class VitPoseForPoseEstimation(VitPosePreTrainedModel): + def __init__(self, config: VitPoseConfig) -> None: + super().__init__(config) + + self.backbone = load_backbone(config) + + # add backbone attributes + if not hasattr(self.backbone.config, "hidden_size"): + raise ValueError("The backbone should have a hidden_size attribute") + if not hasattr(self.backbone.config, "image_size"): + raise ValueError("The backbone should have an image_size attribute") + if not hasattr(self.backbone.config, "patch_size"): + raise ValueError("The backbone should have a patch_size attribute") + + self.head = VitPoseSimpleDecoder(config) if config.use_simple_decoder else VitPoseClassicDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VITPOSE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=VitPoseEstimatorOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + dataset_index: Optional[torch.Tensor] = None, + flip_pairs: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[tuple, VitPoseEstimatorOutput]: + """ + Returns: + + Examples: + + ```python + >>> from transformers import AutoImageProcessor, VitPoseForPoseEstimation + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> processor = 
AutoImageProcessor.from_pretrained("usyd-community/vitpose-base-simple") + >>> model = VitPoseForPoseEstimation.from_pretrained("usyd-community/vitpose-base-simple") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> boxes = [[[412.8, 157.61, 53.05, 138.01], [384.43, 172.21, 15.12, 35.74]]] + >>> inputs = processor(image, boxes=boxes, return_tensors="pt") + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> heatmaps = outputs.heatmaps + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + loss = None + if labels is not None: + raise NotImplementedError("Training is not yet supported") + + outputs = self.backbone.forward_with_filtered_kwargs( + pixel_values, + dataset_index=dataset_index, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + # Turn output hidden states in tensor of shape (batch_size, num_channels, height, width) + sequence_output = outputs.feature_maps[-1] if return_dict else outputs[0][-1] + batch_size = sequence_output.shape[0] + patch_height = self.config.backbone_config.image_size[0] // self.config.backbone_config.patch_size[0] + patch_width = self.config.backbone_config.image_size[1] // self.config.backbone_config.patch_size[1] + sequence_output = ( + sequence_output.permute(0, 2, 1).reshape(batch_size, -1, patch_height, patch_width).contiguous() + ) + + heatmaps = self.head(sequence_output, flip_pairs=flip_pairs) + + if not return_dict: + if output_hidden_states: + output = (heatmaps,) + outputs[1:] + else: + output = (heatmaps,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return VitPoseEstimatorOutput( + loss=loss, + heatmaps=heatmaps, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["VitPosePreTrainedModel", "VitPoseForPoseEstimation"] diff --git a/docs/transformers/build/lib/transformers/models/vitpose_backbone/__init__.py b/docs/transformers/build/lib/transformers/models/vitpose_backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..858e93797d26948970f209e6429e4a586a5044a6 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitpose_backbone/__init__.py @@ -0,0 +1,17 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_vitpose_backbone import *
+    from .modeling_vitpose_backbone import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py b/docs/transformers/build/lib/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..439e596273dbc3739b0b8f1cf7621e427571bc4c
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/vitpose_backbone/configuration_vitpose_backbone.py
@@ -0,0 +1,139 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""VitPose backbone configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
+
+
+logger = logging.get_logger(__name__)
+
+
+class VitPoseBackboneConfig(BackboneConfigMixin, PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`VitPoseBackbone`]. It is used to instantiate a
+    VitPose model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the VitPose
+    [usyd-community/vitpose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        image_size (`List[int]`, *optional*, defaults to `[256, 192]`):
+            The size (resolution) of each image.
+        patch_size (`List[int]`, *optional*, defaults to `[16, 16]`):
+            The size (resolution) of each patch.
+        num_channels (`int`, *optional*, defaults to 3):
+            The number of input channels.
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        mlp_ratio (`int`, *optional*, defaults to 4):
+            The ratio of the hidden size in the feedforward network to the hidden size in the attention layers.
+        num_experts (`int`, *optional*, defaults to 1):
+            The number of experts in the MoE layer.
+        part_features (`int`, *optional*):
+            The number of part features to output. Only used in case `num_experts` is greater than 1.
+        hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        qkv_bias (`bool`, *optional*, defaults to `True`):
+            Whether to add a bias to the queries, keys and values.
+        out_features (`List[str]`, *optional*):
+            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
+            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
+            corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+        out_indices (`List[int]`, *optional*):
+            If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
+            many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
+            If unset and `out_features` is unset, will default to the last stage. Must be in the
+            same order as defined in the `stage_names` attribute.
+
+    Example:
+
+    ```python
+    >>> from transformers import VitPoseBackboneConfig, VitPoseBackbone
+
+    >>> # Initializing a VitPose configuration
+    >>> configuration = VitPoseBackboneConfig()
+
+    >>> # Initializing a model (with random weights) from the configuration
+    >>> model = VitPoseBackbone(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "vitpose_backbone"
+
+    def __init__(
+        self,
+        image_size=[256, 192],
+        patch_size=[16, 16],
+        num_channels=3,
+        hidden_size=768,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        mlp_ratio=4,
+        num_experts=1,
+        part_features=256,
+        hidden_act="gelu",
+        hidden_dropout_prob=0.0,
+        attention_probs_dropout_prob=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-12,
+        qkv_bias=True,
+        out_features=None,
+        out_indices=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.mlp_ratio = mlp_ratio
+        self.num_experts = num_experts
+        self.part_features = part_features
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_channels = num_channels
+        self.qkv_bias = qkv_bias
+        self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, num_hidden_layers + 1)]
+        self._out_features, self._out_indices = get_aligned_output_features_output_indices(
+            out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
+        )
+
+
+__all__ = ["VitPoseBackboneConfig"]
diff --git
a/docs/transformers/build/lib/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py b/docs/transformers/build/lib/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..8abce8dfdc3d84bc35ea3ea0a9b4575aa6e57312 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vitpose_backbone/modeling_vitpose_backbone.py @@ -0,0 +1,579 @@ +# coding=utf-8 +# Copyright 2024 University of Sydney and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch VitPose backbone model. + +This code is the same as the original Vision Transformer (ViT) with 2 modifications: +- use of padding=2 in the patch embedding layer +- addition of a mixture-of-experts MLP layer +""" + +import collections.abc +from typing import Callable, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...modeling_outputs import BackboneOutput, BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from ...utils.backbone_utils import BackboneMixin +from .configuration_vitpose_backbone import VitPoseBackboneConfig + + +logger = logging.get_logger(__name__) + +# General docstring +_CONFIG_FOR_DOC = "VitPoseBackboneConfig" + + +class VitPoseBackbonePatchEmbeddings(nn.Module): + """Image to Patch Embedding.""" + + def __init__(self, config): + super().__init__() + + image_size = config.image_size + patch_size = config.patch_size + num_channels = config.num_channels + embed_dim = config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size, padding=2) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + height, width = pixel_values.shape[-2:] + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values) + + embeddings = embeddings.flatten(2).transpose(1, 2) + return embeddings + + +class VitPoseBackboneEmbeddings(nn.Module): + """ + Construct the position and patch embeddings. 
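+
+    Note (added for clarity; the numbers are the worked-out defaults, not new behavior): with the default 256x192
+    input and 16x16 patches, the convolutional projection in `VitPoseBackbonePatchEmbeddings` uses `padding=2` with
+    `stride=16`, so the output grid is `floor((256 + 2 * 2 - 16) / 16) + 1 = 16` by
+    `floor((192 + 2 * 2 - 16) / 16) + 1 = 12`, i.e. `num_patches = 16 * 12 = 192`, which matches
+    `(image_size[0] // patch_size[0]) * (image_size[1] // patch_size[1])`.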
+ """ + + def __init__(self, config: VitPoseBackboneConfig) -> None: + super().__init__() + + self.patch_embeddings = VitPoseBackbonePatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + embeddings = self.patch_embeddings(pixel_values) + + # add positional encoding to each token + embeddings = embeddings + self.position_embeddings[:, 1:] + self.position_embeddings[:, :1] + + embeddings = self.dropout(embeddings) + + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->VitPoseBackbone +class VitPoseBackboneSelfAttention(nn.Module): + def __init__(self, config: VitPoseBackboneConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." 
+ ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VitPoseBackbone +class VitPoseBackboneSelfOutput(nn.Module): + """ + The residual connection is defined in VitPoseBackboneLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. 
+ """ + + def __init__(self, config: VitPoseBackboneConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VitPoseBackbone +class VitPoseBackboneAttention(nn.Module): + def __init__(self, config: VitPoseBackboneConfig) -> None: + super().__init__() + self.attention = VitPoseBackboneSelfAttention(config) + self.output = VitPoseBackboneSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class VitPoseBackboneMoeMLP(nn.Module): + def __init__(self, config: VitPoseBackboneConfig): + super().__init__() + + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + + num_experts = config.num_experts + part_features = config.part_features + + self.part_features = part_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = ACT2FN[config.hidden_act] + self.fc2 = nn.Linear(hidden_features, out_features - part_features) + self.drop = nn.Dropout(config.hidden_dropout_prob) + + self.num_experts = num_experts + experts = [nn.Linear(hidden_features, part_features) for _ in range(num_experts)] + self.experts = nn.ModuleList(experts) + + def forward(self, hidden_state: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + expert_hidden_state = torch.zeros_like(hidden_state[:, :, -self.part_features :]) + + hidden_state = self.fc1(hidden_state) + hidden_state = self.act(hidden_state) + shared_hidden_state = self.fc2(hidden_state) + indices = indices.view(-1, 1, 1) + + # to support ddp training + for i in range(self.num_experts): + selected_index = indices == i + current_hidden_state = self.experts[i](hidden_state) * selected_index + expert_hidden_state = expert_hidden_state + current_hidden_state + + hidden_state = torch.cat([shared_hidden_state, expert_hidden_state], dim=-1) + + return hidden_state + + +class VitPoseBackboneMLP(nn.Module): + def __init__(self, config: 
VitPoseBackboneConfig) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = nn.Linear(in_features, hidden_features, bias=True) + self.activation = ACT2FN[config.hidden_act] + self.fc2 = nn.Linear(hidden_features, out_features, bias=True) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class VitPoseBackboneLayer(nn.Module): + def __init__(self, config: VitPoseBackboneConfig) -> None: + super().__init__() + self.num_experts = config.num_experts + self.attention = VitPoseBackboneAttention(config) + self.mlp = VitPoseBackboneMLP(config) if self.num_experts == 1 else VitPoseBackboneMoeMLP(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + dataset_index: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + # Validate dataset_index when using multiple experts + if self.num_experts > 1 and dataset_index is None: + raise ValueError( + "dataset_index must be provided when using multiple experts " + f"(num_experts={self.num_experts}). Please provide dataset_index " + "to the forward pass." + ) + self_attention_outputs = self.attention( + self.layernorm_before(hidden_states), # in VitPoseBackbone, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + layer_output = self.layernorm_after(hidden_states) + if self.num_experts == 1: + layer_output = self.mlp(layer_output) + else: + layer_output = self.mlp(layer_output, indices=dataset_index) + + # second residual connection + layer_output = layer_output + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VitPoseBackbone +class VitPoseBackboneEncoder(nn.Module): + def __init__(self, config: VitPoseBackboneConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([VitPoseBackboneLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + dataset_index: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + dataset_index, + layer_head_mask, + 
output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, dataset_index, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class VitPoseBackbonePreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = VitPoseBackboneConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["VitPoseBackboneEmbeddings", "VitPoseBackboneLayer"] + _supports_sdpa = True + _supports_flash_attn_2 = True + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm, VitPoseBackboneEmbeddings]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, VitPoseBackboneEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + +VITPOSE_BACKBONE_START_DOCSTRING = r""" + This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it + as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and + behavior. + + Parameters: + config ([`VitPoseBackboneConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +VITPOSE_BACKBONE_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. + + dataset_index (`torch.Tensor` of shape `(batch_size,)`): + Index to use in the Mixture-of-Experts (MoE) blocks of the backbone. + + This corresponds to the dataset index used during training, e.g. index 0 refers to COCO. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. 
See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The VitPose backbone useful for downstream tasks.", + VITPOSE_BACKBONE_START_DOCSTRING, +) +class VitPoseBackbone(VitPoseBackbonePreTrainedModel, BackboneMixin): + def __init__(self, config: VitPoseBackboneConfig): + super().__init__(config) + super()._init_backbone(config) + + self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)] + self.embeddings = VitPoseBackboneEmbeddings(config) + self.encoder = VitPoseBackboneEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VITPOSE_BACKBONE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BackboneOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: torch.Tensor, + dataset_index: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + """ + Returns: + + Examples: + + ```python + >>> from transformers import VitPoseBackboneConfig, VitPoseBackbone + >>> import torch + + >>> config = VitPoseBackboneConfig(out_indices=[-1]) + >>> model = VitPoseBackbone(config) + + >>> pixel_values = torch.randn(1, 3, 256, 192) + >>> dataset_index = torch.tensor([1]) + >>> outputs = model(pixel_values, dataset_index) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings(pixel_values) + + outputs = self.encoder( + embedding_output, + dataset_index=dataset_index, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + hidden_states = outputs.hidden_states if return_dict else outputs[1] + + feature_maps = () + for stage, hidden_state in zip(self.stage_names, hidden_states): + if stage in self.out_features: + hidden_state = self.layernorm(hidden_state) + feature_maps += (hidden_state,) + + if not return_dict: + if output_hidden_states: + output = (feature_maps,) + outputs[1:] + else: + output = (feature_maps,) + outputs[2:] + return output + + return BackboneOutput( + feature_maps=feature_maps, + hidden_states=outputs.hidden_states if output_hidden_states else None, + attentions=outputs.attentions, + ) + + +__all__ = ["VitPoseBackbonePreTrainedModel", "VitPoseBackbone"] diff --git a/docs/transformers/build/lib/transformers/models/vits/__init__.py b/docs/transformers/build/lib/transformers/models/vits/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf54b7884848784ac7e94062be8da91f4fedcf7 --- 
/dev/null +++ b/docs/transformers/build/lib/transformers/models/vits/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_vits import * + from .modeling_vits import * + from .tokenization_vits import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/vits/configuration_vits.py b/docs/transformers/build/lib/transformers/models/vits/configuration_vits.py new file mode 100644 index 0000000000000000000000000000000000000000..6de2591b0f3addd9536e0ed81023251666764163 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vits/configuration_vits.py @@ -0,0 +1,253 @@ +# coding=utf-8 +# Copyright 2023 The Kakao Enterprise Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""VITS model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class VitsConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VitsModel`]. It is used to instantiate a VITS + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the VITS + [facebook/mms-tts-eng](https://huggingface.co/facebook/mms-tts-eng) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 38): + Vocabulary size of the VITS model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed to the forward method of [`VitsModel`]. + hidden_size (`int`, *optional*, defaults to 192): + Dimensionality of the text encoder layers. + num_hidden_layers (`int`, *optional*, defaults to 6): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 2): + Number of attention heads for each attention layer in the Transformer encoder. 
+ window_size (`int`, *optional*, defaults to 4): + Window size for the relative positional embeddings in the attention layers of the Transformer encoder. + use_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the key, query, value projection layers in the Transformer encoder. + ffn_dim (`int`, *optional*, defaults to 768): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + layerdrop (`float`, *optional*, defaults to 0.1): + The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) + for more details. + ffn_kernel_size (`int`, *optional*, defaults to 3): + Kernel size of the 1D convolution layers used by the feed-forward network in the Transformer encoder. + flow_size (`int`, *optional*, defaults to 192): + Dimensionality of the flow layers. + spectrogram_bins (`int`, *optional*, defaults to 513): + Number of frequency bins in the target spectrogram. + hidden_act (`str` or `function`, *optional*, defaults to `"relu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + hidden_dropout (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings and encoder. + attention_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + activation_dropout (`float`, *optional*, defaults to 0.1): + The dropout ratio for activations inside the fully connected layer. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + use_stochastic_duration_prediction (`bool`, *optional*, defaults to `True`): + Whether to use the stochastic duration prediction module or the regular duration predictor. + num_speakers (`int`, *optional*, defaults to 1): + Number of speakers if this is a multi-speaker model. + speaker_embedding_size (`int`, *optional*, defaults to 0): + Number of channels used by the speaker embeddings. Is zero for single-speaker models. + upsample_initial_channel (`int`, *optional*, defaults to 512): + The number of input channels into the HiFi-GAN upsampling network. + upsample_rates (`Tuple[int]` or `List[int]`, *optional*, defaults to `[8, 8, 2, 2]`): + A tuple of integers defining the stride of each 1D convolutional layer in the HiFi-GAN upsampling network. + The length of `upsample_rates` defines the number of convolutional layers and has to match the length of + `upsample_kernel_sizes`. + upsample_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[16, 16, 4, 4]`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the HiFi-GAN upsampling + network. The length of `upsample_kernel_sizes` defines the number of convolutional layers and has to match + the length of `upsample_rates`. + resblock_kernel_sizes (`Tuple[int]` or `List[int]`, *optional*, defaults to `[3, 7, 11]`): + A tuple of integers defining the kernel sizes of the 1D convolutional layers in the HiFi-GAN + multi-receptive field fusion (MRF) module.
+ resblock_dilation_sizes (`Tuple[Tuple[int]]` or `List[List[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`): + A nested tuple of integers defining the dilation rates of the dilated 1D convolutional layers in the + HiFi-GAN multi-receptive field fusion (MRF) module. + leaky_relu_slope (`float`, *optional*, defaults to 0.1): + The angle of the negative slope used by the leaky ReLU activation. + depth_separable_channels (`int`, *optional*, defaults to 2): + Number of channels to use in each depth-separable block. + depth_separable_num_layers (`int`, *optional*, defaults to 3): + Number of convolutional layers to use in each depth-separable block. + duration_predictor_flow_bins (`int`, *optional*, defaults to 10): + Number of channels to map using the unconstrained rational spline in the duration predictor model. + duration_predictor_tail_bound (`float`, *optional*, defaults to 5.0): + Value of the tail bin boundary when computing the unconstrained rational spline in the duration predictor + model. + duration_predictor_kernel_size (`int`, *optional*, defaults to 3): + Kernel size of the 1D convolution layers used in the duration predictor model. + duration_predictor_dropout (`float`, *optional*, defaults to 0.5): + The dropout ratio for the duration predictor model. + duration_predictor_num_flows (`int`, *optional*, defaults to 4): + Number of flow stages used by the duration predictor model. + duration_predictor_filter_channels (`int`, *optional*, defaults to 256): + Number of channels for the convolution layers used in the duration predictor model. + prior_encoder_num_flows (`int`, *optional*, defaults to 4): + Number of flow stages used by the prior encoder flow model. + prior_encoder_num_wavenet_layers (`int`, *optional*, defaults to 4): + Number of WaveNet layers used by the prior encoder flow model. + posterior_encoder_num_wavenet_layers (`int`, *optional*, defaults to 16): + Number of WaveNet layers used by the posterior encoder model. + wavenet_kernel_size (`int`, *optional*, defaults to 5): + Kernel size of the 1D convolution layers used in the WaveNet model. + wavenet_dilation_rate (`int`, *optional*, defaults to 1): + Dilation rates of the dilated 1D convolutional layers used in the WaveNet model. + wavenet_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the WaveNet layers. + speaking_rate (`float`, *optional*, defaults to 1.0): + Speaking rate. Larger values give faster synthesised speech. + noise_scale (`float`, *optional*, defaults to 0.667): + How random the speech prediction is. Larger values create more variation in the predicted speech. + noise_scale_duration (`float`, *optional*, defaults to 0.8): + How random the duration prediction is. Larger values create more variation in the predicted durations. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the output audio waveform is digitalized, expressed in hertz (Hz).
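+
+ With the default `upsample_rates`, the HiFi-GAN decoder multiplies the spectrogram length by 8 * 8 * 2 * 2 = 256,
+ i.e. it generates 256 waveform samples per spectrogram frame. A minimal sketch of this relation (for intuition
+ only, not part of the configuration API):
+
+ ```python
+ >>> import math
+
+ >>> math.prod([8, 8, 2, 2])  # total upsampling factor with the default `upsample_rates`
+ 256
+ >>> # e.g. 100 spectrogram frames -> 100 * 256 = 25600 samples, or 1.6 seconds at `sampling_rate=16000`
+ ```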
+ + Example: + + ```python + >>> from transformers import VitsModel, VitsConfig + + >>> # Initializing a "facebook/mms-tts-eng" style configuration + >>> configuration = VitsConfig() + + >>> # Initializing a model (with random weights) from the "facebook/mms-tts-eng" style configuration + >>> model = VitsModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vits" + + def __init__( + self, + vocab_size=38, + hidden_size=192, + num_hidden_layers=6, + num_attention_heads=2, + window_size=4, + use_bias=True, + ffn_dim=768, + layerdrop=0.1, + ffn_kernel_size=3, + flow_size=192, + spectrogram_bins=513, + hidden_act="relu", + hidden_dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + use_stochastic_duration_prediction=True, + num_speakers=1, + speaker_embedding_size=0, + upsample_initial_channel=512, + upsample_rates=[8, 8, 2, 2], + upsample_kernel_sizes=[16, 16, 4, 4], + resblock_kernel_sizes=[3, 7, 11], + resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], + leaky_relu_slope=0.1, + depth_separable_channels=2, + depth_separable_num_layers=3, + duration_predictor_flow_bins=10, + duration_predictor_tail_bound=5.0, + duration_predictor_kernel_size=3, + duration_predictor_dropout=0.5, + duration_predictor_num_flows=4, + duration_predictor_filter_channels=256, + prior_encoder_num_flows=4, + prior_encoder_num_wavenet_layers=4, + posterior_encoder_num_wavenet_layers=16, + wavenet_kernel_size=5, + wavenet_dilation_rate=1, + wavenet_dropout=0.0, + speaking_rate=1.0, + noise_scale=0.667, + noise_scale_duration=0.8, + sampling_rate=16_000, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.window_size = window_size + self.use_bias = use_bias + self.ffn_dim = ffn_dim + self.layerdrop = layerdrop + self.ffn_kernel_size = ffn_kernel_size + self.flow_size = flow_size + self.spectrogram_bins = spectrogram_bins + self.hidden_act = hidden_act + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.use_stochastic_duration_prediction = use_stochastic_duration_prediction + self.num_speakers = num_speakers + self.speaker_embedding_size = speaker_embedding_size + self.upsample_initial_channel = upsample_initial_channel + self.upsample_rates = upsample_rates + self.upsample_kernel_sizes = upsample_kernel_sizes + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.leaky_relu_slope = leaky_relu_slope + self.depth_separable_channels = depth_separable_channels + self.depth_separable_num_layers = depth_separable_num_layers + self.duration_predictor_flow_bins = duration_predictor_flow_bins + self.duration_predictor_tail_bound = duration_predictor_tail_bound + self.duration_predictor_kernel_size = duration_predictor_kernel_size + self.duration_predictor_dropout = duration_predictor_dropout + self.duration_predictor_num_flows = duration_predictor_num_flows + self.duration_predictor_filter_channels = duration_predictor_filter_channels + self.prior_encoder_num_flows = prior_encoder_num_flows + self.prior_encoder_num_wavenet_layers = prior_encoder_num_wavenet_layers + self.posterior_encoder_num_wavenet_layers = 
posterior_encoder_num_wavenet_layers + self.wavenet_kernel_size = wavenet_kernel_size + self.wavenet_dilation_rate = wavenet_dilation_rate + self.wavenet_dropout = wavenet_dropout + self.speaking_rate = speaking_rate + self.noise_scale = noise_scale + self.noise_scale_duration = noise_scale_duration + self.sampling_rate = sampling_rate + + if len(upsample_kernel_sizes) != len(upsample_rates): + raise ValueError( + f"The length of `upsample_kernel_sizes` ({len(upsample_kernel_sizes)}) must match the length of " + f"`upsample_rates` ({len(upsample_rates)})" + ) + + super().__init__(**kwargs) + + +__all__ = ["VitsConfig"] diff --git a/docs/transformers/build/lib/transformers/models/vits/convert_original_checkpoint.py b/docs/transformers/build/lib/transformers/models/vits/convert_original_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..7f122e86fa54ac75b8051ef13836e38924f57d1a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vits/convert_original_checkpoint.py @@ -0,0 +1,390 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert VITS checkpoint.""" + +import argparse +import json +import tempfile + +import torch +from huggingface_hub import hf_hub_download + +from transformers import VitsConfig, VitsModel, VitsTokenizer, logging + + +logging.set_verbosity_info() +logger = logging.get_logger("transformers.models.vits") + +MAPPING_TEXT_ENCODER = { + "enc_p.emb": "text_encoder.embed_tokens", + "enc_p.encoder.attn_layers.*.conv_k": "text_encoder.encoder.layers.*.attention.k_proj", + "enc_p.encoder.attn_layers.*.conv_v": "text_encoder.encoder.layers.*.attention.v_proj", + "enc_p.encoder.attn_layers.*.conv_q": "text_encoder.encoder.layers.*.attention.q_proj", + "enc_p.encoder.attn_layers.*.conv_o": "text_encoder.encoder.layers.*.attention.out_proj", + "enc_p.encoder.attn_layers.*.emb_rel_k": "text_encoder.encoder.layers.*.attention.emb_rel_k", + "enc_p.encoder.attn_layers.*.emb_rel_v": "text_encoder.encoder.layers.*.attention.emb_rel_v", + "enc_p.encoder.norm_layers_1.*.gamma": "text_encoder.encoder.layers.*.layer_norm.weight", + "enc_p.encoder.norm_layers_1.*.beta": "text_encoder.encoder.layers.*.layer_norm.bias", + "enc_p.encoder.ffn_layers.*.conv_1": "text_encoder.encoder.layers.*.feed_forward.conv_1", + "enc_p.encoder.ffn_layers.*.conv_2": "text_encoder.encoder.layers.*.feed_forward.conv_2", + "enc_p.encoder.norm_layers_2.*.gamma": "text_encoder.encoder.layers.*.final_layer_norm.weight", + "enc_p.encoder.norm_layers_2.*.beta": "text_encoder.encoder.layers.*.final_layer_norm.bias", + "enc_p.proj": "text_encoder.project", +} +MAPPING_STOCHASTIC_DURATION_PREDICTOR = { + "dp.pre": "duration_predictor.conv_pre", + "dp.proj": "duration_predictor.conv_proj", + "dp.convs.convs_sep.*": "duration_predictor.conv_dds.convs_dilated.*", + "dp.convs.convs_1x1.*": "duration_predictor.conv_dds.convs_pointwise.*", + "dp.convs.norms_1.*.gamma": 
"duration_predictor.conv_dds.norms_1.*.weight", + "dp.convs.norms_1.*.beta": "duration_predictor.conv_dds.norms_1.*.bias", + "dp.convs.norms_2.*.gamma": "duration_predictor.conv_dds.norms_2.*.weight", + "dp.convs.norms_2.*.beta": "duration_predictor.conv_dds.norms_2.*.bias", + "dp.flows.0.logs": "duration_predictor.flows.0.log_scale", + "dp.flows.0.m": "duration_predictor.flows.0.translate", + "dp.flows.*.pre": "duration_predictor.flows.*.conv_pre", + "dp.flows.*.proj": "duration_predictor.flows.*.conv_proj", + "dp.flows.*.convs.convs_1x1.0": "duration_predictor.flows.*.conv_dds.convs_pointwise.0", + "dp.flows.*.convs.convs_1x1.1": "duration_predictor.flows.*.conv_dds.convs_pointwise.1", + "dp.flows.*.convs.convs_1x1.2": "duration_predictor.flows.*.conv_dds.convs_pointwise.2", + "dp.flows.*.convs.convs_sep.0": "duration_predictor.flows.*.conv_dds.convs_dilated.0", + "dp.flows.*.convs.convs_sep.1": "duration_predictor.flows.*.conv_dds.convs_dilated.1", + "dp.flows.*.convs.convs_sep.2": "duration_predictor.flows.*.conv_dds.convs_dilated.2", + "dp.flows.*.convs.norms_1.0.gamma": "duration_predictor.flows.*.conv_dds.norms_1.0.weight", + "dp.flows.*.convs.norms_1.0.beta": "duration_predictor.flows.*.conv_dds.norms_1.0.bias", + "dp.flows.*.convs.norms_1.1.gamma": "duration_predictor.flows.*.conv_dds.norms_1.1.weight", + "dp.flows.*.convs.norms_1.1.beta": "duration_predictor.flows.*.conv_dds.norms_1.1.bias", + "dp.flows.*.convs.norms_1.2.gamma": "duration_predictor.flows.*.conv_dds.norms_1.2.weight", + "dp.flows.*.convs.norms_1.2.beta": "duration_predictor.flows.*.conv_dds.norms_1.2.bias", + "dp.flows.*.convs.norms_2.0.gamma": "duration_predictor.flows.*.conv_dds.norms_2.0.weight", + "dp.flows.*.convs.norms_2.0.beta": "duration_predictor.flows.*.conv_dds.norms_2.0.bias", + "dp.flows.*.convs.norms_2.1.gamma": "duration_predictor.flows.*.conv_dds.norms_2.1.weight", + "dp.flows.*.convs.norms_2.1.beta": "duration_predictor.flows.*.conv_dds.norms_2.1.bias", + "dp.flows.*.convs.norms_2.2.gamma": "duration_predictor.flows.*.conv_dds.norms_2.2.weight", + "dp.flows.*.convs.norms_2.2.beta": "duration_predictor.flows.*.conv_dds.norms_2.2.bias", + "dp.post_pre": "duration_predictor.post_conv_pre", + "dp.post_proj": "duration_predictor.post_conv_proj", + "dp.post_convs.convs_sep.*": "duration_predictor.post_conv_dds.convs_dilated.*", + "dp.post_convs.convs_1x1.*": "duration_predictor.post_conv_dds.convs_pointwise.*", + "dp.post_convs.norms_1.*.gamma": "duration_predictor.post_conv_dds.norms_1.*.weight", + "dp.post_convs.norms_1.*.beta": "duration_predictor.post_conv_dds.norms_1.*.bias", + "dp.post_convs.norms_2.*.gamma": "duration_predictor.post_conv_dds.norms_2.*.weight", + "dp.post_convs.norms_2.*.beta": "duration_predictor.post_conv_dds.norms_2.*.bias", + "dp.post_flows.0.logs": "duration_predictor.post_flows.0.log_scale", + "dp.post_flows.0.m": "duration_predictor.post_flows.0.translate", + "dp.post_flows.*.pre": "duration_predictor.post_flows.*.conv_pre", + "dp.post_flows.*.proj": "duration_predictor.post_flows.*.conv_proj", + "dp.post_flows.*.convs.convs_1x1.0": "duration_predictor.post_flows.*.conv_dds.convs_pointwise.0", + "dp.post_flows.*.convs.convs_1x1.1": "duration_predictor.post_flows.*.conv_dds.convs_pointwise.1", + "dp.post_flows.*.convs.convs_1x1.2": "duration_predictor.post_flows.*.conv_dds.convs_pointwise.2", + "dp.post_flows.*.convs.convs_sep.0": "duration_predictor.post_flows.*.conv_dds.convs_dilated.0", + "dp.post_flows.*.convs.convs_sep.1": 
"duration_predictor.post_flows.*.conv_dds.convs_dilated.1", + "dp.post_flows.*.convs.convs_sep.2": "duration_predictor.post_flows.*.conv_dds.convs_dilated.2", + "dp.post_flows.*.convs.norms_1.0.gamma": "duration_predictor.post_flows.*.conv_dds.norms_1.0.weight", + "dp.post_flows.*.convs.norms_1.0.beta": "duration_predictor.post_flows.*.conv_dds.norms_1.0.bias", + "dp.post_flows.*.convs.norms_1.1.gamma": "duration_predictor.post_flows.*.conv_dds.norms_1.1.weight", + "dp.post_flows.*.convs.norms_1.1.beta": "duration_predictor.post_flows.*.conv_dds.norms_1.1.bias", + "dp.post_flows.*.convs.norms_1.2.gamma": "duration_predictor.post_flows.*.conv_dds.norms_1.2.weight", + "dp.post_flows.*.convs.norms_1.2.beta": "duration_predictor.post_flows.*.conv_dds.norms_1.2.bias", + "dp.post_flows.*.convs.norms_2.0.gamma": "duration_predictor.post_flows.*.conv_dds.norms_2.0.weight", + "dp.post_flows.*.convs.norms_2.0.beta": "duration_predictor.post_flows.*.conv_dds.norms_2.0.bias", + "dp.post_flows.*.convs.norms_2.1.gamma": "duration_predictor.post_flows.*.conv_dds.norms_2.1.weight", + "dp.post_flows.*.convs.norms_2.1.beta": "duration_predictor.post_flows.*.conv_dds.norms_2.1.bias", + "dp.post_flows.*.convs.norms_2.2.gamma": "duration_predictor.post_flows.*.conv_dds.norms_2.2.weight", + "dp.post_flows.*.convs.norms_2.2.beta": "duration_predictor.post_flows.*.conv_dds.norms_2.2.bias", + "dp.cond": "duration_predictor.cond", # num_speakers > 1 +} +MAPPING_FLOW = { + "flow.flows.*.pre": "flow.flows.*.conv_pre", + "flow.flows.*.enc.in_layers.0": "flow.flows.*.wavenet.in_layers.0", + "flow.flows.*.enc.in_layers.1": "flow.flows.*.wavenet.in_layers.1", + "flow.flows.*.enc.in_layers.2": "flow.flows.*.wavenet.in_layers.2", + "flow.flows.*.enc.in_layers.3": "flow.flows.*.wavenet.in_layers.3", + "flow.flows.*.enc.res_skip_layers.0": "flow.flows.*.wavenet.res_skip_layers.0", + "flow.flows.*.enc.res_skip_layers.1": "flow.flows.*.wavenet.res_skip_layers.1", + "flow.flows.*.enc.res_skip_layers.2": "flow.flows.*.wavenet.res_skip_layers.2", + "flow.flows.*.enc.res_skip_layers.3": "flow.flows.*.wavenet.res_skip_layers.3", + "flow.flows.*.enc.cond_layer": "flow.flows.*.wavenet.cond_layer", # num_speakers > 1 + "flow.flows.*.post": "flow.flows.*.conv_post", +} +MAPPING_GENERATOR = { + "dec.conv_pre": "decoder.conv_pre", + "dec.ups.0": "decoder.upsampler.0", + "dec.ups.1": "decoder.upsampler.1", + "dec.ups.2": "decoder.upsampler.2", + "dec.ups.3": "decoder.upsampler.3", + "dec.resblocks.*.convs1.0": "decoder.resblocks.*.convs1.0", + "dec.resblocks.*.convs1.1": "decoder.resblocks.*.convs1.1", + "dec.resblocks.*.convs1.2": "decoder.resblocks.*.convs1.2", + "dec.resblocks.*.convs2.0": "decoder.resblocks.*.convs2.0", + "dec.resblocks.*.convs2.1": "decoder.resblocks.*.convs2.1", + "dec.resblocks.*.convs2.2": "decoder.resblocks.*.convs2.2", + "dec.conv_post": "decoder.conv_post", + "dec.cond": "decoder.cond", # num_speakers > 1 +} +MAPPING_POSTERIOR_ENCODER = { + "enc_q.pre": "posterior_encoder.conv_pre", + "enc_q.enc.in_layers.*": "posterior_encoder.wavenet.in_layers.*", + "enc_q.enc.res_skip_layers.*": "posterior_encoder.wavenet.res_skip_layers.*", + "enc_q.enc.cond_layer": "posterior_encoder.wavenet.cond_layer", # num_speakers > 1 + "enc_q.proj": "posterior_encoder.conv_proj", +} +MAPPING = { + **MAPPING_TEXT_ENCODER, + **MAPPING_STOCHASTIC_DURATION_PREDICTOR, + **MAPPING_FLOW, + **MAPPING_GENERATOR, + **MAPPING_POSTERIOR_ENCODER, + "emb_g": "embed_speaker", # num_speakers > 1 +} +TOP_LEVEL_KEYS = [] +IGNORE_KEYS = [] + + +def 
set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + # strip off the kernel dimension at the end (original weights are Conv1d) + if key.endswith(".k_proj") or key.endswith(".v_proj") or key.endswith(".q_proj") or key.endswith(".out_proj"): + value = value.squeeze(-1) + + if hf_shape != value.shape: + raise ValueError( + f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be" + f" {value.shape} for {full_name}" + ) + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + elif weight_type == "running_mean": + hf_pointer.running_mean.data = value + elif weight_type == "running_var": + hf_pointer.running_var.data = value + elif weight_type == "num_batches_tracked": + hf_pointer.num_batches_tracked.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + ('.' + weight_type if weight_type is not None else '')} was initialized from {full_name}.") + + +def should_ignore(name, ignore_keys): + for key in ignore_keys: + if key.endswith(".*"): + if name.startswith(key[:-1]): + return True + elif ".*." in key: + prefix, suffix = key.split(".*.") + if prefix in name and suffix in name: + return True + elif key in name: + return True + return False + + +def recursively_load_weights(fairseq_dict, hf_model): + unused_weights = [] + + for name, value in fairseq_dict.items(): + if should_ignore(name, IGNORE_KEYS): + logger.info(f"{name} was ignored") + continue + + is_used = False + for key, mapped_key in MAPPING.items(): + if key.endswith(".*"): + key = key[:-1] + elif "*" in key: + prefix, suffix = key.split(".*.") + if prefix in name and suffix in name: + key = suffix + + if key in name: + is_used = True + if mapped_key.endswith(".*"): + layer_index = name.split(key)[-1].split(".")[0] + mapped_key = mapped_key.replace("*", layer_index) + elif "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + + # remap the layer index since we removed the Flip layers + if "flow.flows" in mapped_key: + layer_index = str(int(layer_index) // 2) + if "duration_predictor.flows" in mapped_key or "duration_predictor.post_flows" in mapped_key: + layer_index = str(int(layer_index) // 2 + 1) + + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "bias" in name: + weight_type = "bias" + elif "weight" in name: + weight_type = "weight" + elif "running_mean" in name: + weight_type = "running_mean" + elif "running_var" in name: + weight_type = "running_var" + elif "num_batches_tracked" in name: + weight_type = "num_batches_tracked" + else: + weight_type = None + set_recursively(hf_model, mapped_key, value, name, weight_type) + continue + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + +@torch.no_grad() +def convert_checkpoint( + pytorch_dump_folder_path, + checkpoint_path=None, + config_path=None, + vocab_path=None, + language=None, + num_speakers=None, + sampling_rate=None, + repo_id=None, +): + """ + Copy/paste/tweak model's weights to transformers design. 
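+
+ Example (a hypothetical invocation; with no `checkpoint_path`, the script downloads the MMS-TTS checkpoint for
+ the given language from the Hub and converts it):
+
+ ```python
+ >>> convert_checkpoint("./vits-mms-eng", language="eng")
+ ```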
+ """ + if config_path is not None: + config = VitsConfig.from_pretrained(config_path) + else: + config = VitsConfig() + + if num_speakers: + config.num_speakers = num_speakers + config.speaker_embedding_size = 256 + + if sampling_rate: + config.sampling_rate = sampling_rate + + if checkpoint_path is None: + logger.info(f"***Converting model: facebook/mms-tts {language}***") + + vocab_path = hf_hub_download( + repo_id="facebook/mms-tts", + filename="vocab.txt", + subfolder=f"models/{language}", + ) + config_file = hf_hub_download( + repo_id="facebook/mms-tts", + filename="config.json", + subfolder=f"models/{language}", + ) + checkpoint_path = hf_hub_download( + repo_id="facebook/mms-tts", + filename="G_100000.pth", + subfolder=f"models/{language}", + ) + + with open(config_file, "r") as f: + data = f.read() + hps = json.loads(data) + + is_uroman = hps["data"]["training_files"].split(".")[-1] == "uroman" + if is_uroman: + logger.warning("For this checkpoint, you should use `uroman` to convert input text before tokenizing it!") + else: + logger.info(f"***Converting model: {checkpoint_path}***") + is_uroman = False + + # original VITS checkpoint + if vocab_path is None: + _pad = "_" + _punctuation = ';:,.!?¡¿—…"«»“” ' + _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" + symbols = _pad + _punctuation + _letters + _letters_ipa + symbol_to_id = {s: i for i, s in enumerate(symbols)} + phonemize = True + else: + # Save vocab as temporary json file + symbols = [line.replace("\n", "") for line in open(vocab_path, encoding="utf-8").readlines()] + symbol_to_id = {s: i for i, s in enumerate(symbols)} + # MMS-TTS does not use a token, so we set to the token used to space characters + _pad = symbols[0] + phonemize = False + + with tempfile.NamedTemporaryFile() as tf: + with open(tf.name, "w", encoding="utf-8") as f: + f.write(json.dumps(symbol_to_id, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + tokenizer = VitsTokenizer(tf.name, language=language, phonemize=phonemize, is_uroman=is_uroman, pad_token=_pad) + + config.vocab_size = len(symbols) + model = VitsModel(config) + + model.decoder.apply_weight_norm() + + orig_checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True) + recursively_load_weights(orig_checkpoint["model"], model) + + model.decoder.remove_weight_norm() + + model.save_pretrained(pytorch_dump_folder_path) + tokenizer.save_pretrained(pytorch_dump_folder_path) + + if repo_id: + print("Pushing to the hub...") + tokenizer.push_to_hub(repo_id) + model.push_to_hub(repo_id) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint_path", default=None, type=str, help="Local path to original checkpoint") + parser.add_argument("--vocab_path", default=None, type=str, help="Path to vocab.txt") + parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument("--language", default=None, type=str, help="Tokenizer language (three-letter code)") + parser.add_argument("--num_speakers", default=None, type=int, help="Number of speakers") + parser.add_argument( + "--sampling_rate", default=None, type=int, help="Sampling rate on which the model was trained." + ) + parser.add_argument( + "--pytorch_dump_folder_path", required=True, default=None, type=str, help="Path to the output PyTorch model." 
+ ) + parser.add_argument( + "--push_to_hub", default=None, type=str, help="Where to upload the converted model on the 🤗 hub." + ) + + args = parser.parse_args() + convert_checkpoint( + args.pytorch_dump_folder_path, + args.checkpoint_path, + args.config_path, + args.vocab_path, + args.language, + args.num_speakers, + args.sampling_rate, + args.push_to_hub, + ) diff --git a/docs/transformers/build/lib/transformers/models/vits/modeling_vits.py b/docs/transformers/build/lib/transformers/models/vits/modeling_vits.py new file mode 100644 index 0000000000000000000000000000000000000000..59483d3e61364e91b08521994ae1d9cb60745f47 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vits/modeling_vits.py @@ -0,0 +1,1493 @@ +# coding=utf-8 +# Copyright 2023 The Kakao Enterprise Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch VITS model.""" + +import math +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + ModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_vits import VitsConfig + + +logger = logging.get_logger(__name__) + + +# General docstring +_CONFIG_FOR_DOC = "VitsConfig" + + +@dataclass +class VitsModelOutput(ModelOutput): + """ + Describes the outputs for the VITS model, with potential hidden states and attentions. + + Args: + waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + The final audio waveform predicted by the model. + sequence_lengths (`torch.FloatTensor` of shape `(batch_size,)`): + The length in samples of each element in the `waveform` batch. + spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`): + The log-mel spectrogram predicted at the output of the flow model. This spectrogram is passed to the Hi-Fi + GAN decoder model to obtain the final audio waveform. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + waveform: Optional[torch.FloatTensor] = None + sequence_lengths: Optional[torch.FloatTensor] = None + spectrogram: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class VitsTextEncoderOutput(ModelOutput): + """ + Describes the outputs for the VITS text encoder model, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + prior_means (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The predicted mean values of the prior distribution for the latent text variables. + prior_log_variances (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + The predicted log-variance values of the prior distribution for the latent text variables. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attention weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: Optional[torch.FloatTensor] = None + prior_means: Optional[torch.FloatTensor] = None + prior_log_variances: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels): + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :num_channels, :]) + s_act = torch.sigmoid(in_act[:, num_channels:, :]) + acts = t_act * s_act + return acts + + +def _unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + reverse=False, + tail_bound=5.0, + min_bin_width=1e-3, + min_bin_height=1e-3, + min_derivative=1e-3, +): + """ + This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the + `tail_bound`, the transform behaves as an identity function. + + Args: + inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`: + Second half of the hidden-states input to the Vits convolutional flow module. 
+ unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): + First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection + layer in the convolutional flow module. + unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): + Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection + layer in the convolutional flow module. + unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): + Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection + layer in the convolutional flow module. + reverse (`bool`, *optional*, defaults to `False`): + Whether the model is being run in reverse mode. + tail_bound (`float`, *optional*, defaults to 5.0): + Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the + transform behaves as an identity function. + min_bin_width (`float`, *optional*, defaults to 1e-3): + Minimum bin value across the width dimension for the piecewise rational quadratic function. + min_bin_height (`float`, *optional*, defaults to 1e-3): + Minimum bin value across the height dimension for the piecewise rational quadratic function. + min_derivative (`float`, *optional*, defaults to 1e-3): + Minimum bin value across the derivatives for the piecewise rational quadratic function. + Returns: + outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`): + Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits + applied. + log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`): + Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound` + limits applied. + """ + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + log_abs_det = torch.zeros_like(inputs) + constant = np.log(np.exp(1 - min_derivative) - 1) + + unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1)) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + log_abs_det[outside_interval_mask] = 0.0 + + outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + reverse=reverse, + tail_bound=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + return outputs, log_abs_det + + +def _rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + reverse, + tail_bound, + min_bin_width, + min_bin_height, + min_derivative, +): + """ + This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike + `_unconstrained_rational_quadratic_spline`, this function requires all inputs to lie inside the `tail_bound` + interval and raises an error otherwise.
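+
+ Concretely, following Durkan et al. (2019), "Neural Spline Flows" (https://arxiv.org/abs/1906.04032): for an
+ input falling into bin `k` with knots `(x_k, y_k)`, bin slope `s_k = (y_{k+1} - y_k) / (x_{k+1} - x_k)`, knot
+ derivatives `d_k`, and `theta = (x - x_k) / (x_{k+1} - x_k)`, the forward mode below evaluates
+
+ `g(x) = y_k + (y_{k+1} - y_k) * (s_k * theta^2 + d_k * theta * (1 - theta)) / (s_k + (d_{k+1} + d_k - 2 * s_k) * theta * (1 - theta))`
+
+ and the reverse mode solves the corresponding quadratic equation for `theta`.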
+ + Args: + inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`): + Second half of the hidden-states input to the Vits convolutional flow module. + unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): + First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection + layer in the convolutional flow module. + unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): + Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection + layer in the convolutional flow module. + unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`): + Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection + layer in the convolutional flow module. + reverse (`bool`): + Whether the model is being run in reverse mode. + tail_bound (`float`): + Upper and lower limit bound for the rational quadratic function; all `inputs` must lie inside + `[-tail_bound, tail_bound]`. + min_bin_width (`float`): + Minimum bin value across the width dimension for the piecewise rational quadratic function. + min_bin_height (`float`): + Minimum bin value across the height dimension for the piecewise rational quadratic function. + min_derivative (`float`): + Minimum bin value across the derivatives for the piecewise rational quadratic function. + Returns: + outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`): + Hidden-states as transformed by the piecewise rational quadratic function. + log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`): + Logarithm of the absolute value of the determinants corresponding to the `outputs`.
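+
+ Because the spline is analytically invertible, a quick sanity check is that running the unconstrained wrapper
+ defined above forward and then in reverse recovers its input up to numerical precision. A minimal sketch with
+ random spline parameters (shapes follow the defaults, e.g. `duration_predictor_flow_bins=10`):
+
+ ```python
+ >>> import torch
+
+ >>> torch.manual_seed(0)
+ >>> x = torch.rand(2, 1, 5) * 8 - 4  # values inside the default tail_bound of 5.0
+ >>> widths = torch.randn(2, 1, 5, 10)
+ >>> heights = torch.randn(2, 1, 5, 10)
+ >>> derivs = torch.randn(2, 1, 5, 9)  # num_bins - 1 values; the wrapper pads the two endpoints
+ >>> y, _ = _unconstrained_rational_quadratic_spline(x, widths, heights, derivs)
+ >>> x_rec, _ = _unconstrained_rational_quadratic_spline(y, widths, heights, derivs, reverse=True)
+ >>> torch.allclose(x, x_rec, atol=1e-4)
+ True
+ ```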
+ """ + upper_bound = tail_bound + lower_bound = -tail_bound + + if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}") + if min_bin_height * num_bins > 1.0: + raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}") + + widths = nn.functional.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound + cumwidths[..., 0] = lower_bound + cumwidths[..., -1] = upper_bound + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives) + + heights = nn.functional.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (upper_bound - lower_bound) * cumheights + lower_bound + cumheights[..., 0] = lower_bound + cumheights[..., -1] = upper_bound + heights = cumheights[..., 1:] - cumheights[..., :-1] + + bin_locations = cumheights if reverse else cumwidths + bin_locations[..., -1] += 1e-6 + bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + bin_idx = bin_idx[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta + if not reverse: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta) + denominator = input_delta + intermediate1 * theta_one_minus_theta + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator) + return outputs, log_abs_det + else: + # find the roots of a quadratic equation + intermediate2 = inputs - input_cumheights + intermediate3 = intermediate2 * intermediate1 + a = input_heights * (input_delta - input_derivatives) + intermediate3 + b = input_heights * input_derivatives - intermediate3 + c = -input_delta * intermediate2 + + discriminant = b.pow(2) - 4 * a * c + if not (discriminant >= 0).all(): + raise RuntimeError(f"invalid discriminant {discriminant}") + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + intermediate1 * 
theta_one_minus_theta + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator) + return outputs, -log_abs_det + + +class VitsWaveNet(torch.nn.Module): + def __init__(self, config: VitsConfig, num_layers: int): + super().__init__() + self.hidden_size = config.hidden_size + self.num_layers = num_layers + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.dropout = nn.Dropout(config.wavenet_dropout) + + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + else: + weight_norm = nn.utils.weight_norm + + if config.speaker_embedding_size != 0: + cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1) + self.cond_layer = weight_norm(cond_layer, name="weight") + + for i in range(num_layers): + dilation = config.wavenet_dilation_rate**i + padding = (config.wavenet_kernel_size * dilation - dilation) // 2 + in_layer = torch.nn.Conv1d( + in_channels=config.hidden_size, + out_channels=2 * config.hidden_size, + kernel_size=config.wavenet_kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < num_layers - 1: + res_skip_channels = 2 * config.hidden_size + else: + res_skip_channels = config.hidden_size + + res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1) + res_skip_layer = weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, inputs, padding_mask, global_conditioning=None): + outputs = torch.zeros_like(inputs) + num_channels_tensor = torch.IntTensor([self.hidden_size]) + + if global_conditioning is not None: + global_conditioning = self.cond_layer(global_conditioning) + + for i in range(self.num_layers): + hidden_states = self.in_layers[i](inputs) + + if global_conditioning is not None: + cond_offset = i * 2 * self.hidden_size + global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :] + else: + global_states = torch.zeros_like(hidden_states) + + acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0]) + acts = self.dropout(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.num_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_size, :] + inputs = (inputs + res_acts) * padding_mask + outputs = outputs + res_skip_acts[:, self.hidden_size :, :] + else: + outputs = outputs + res_skip_acts + + return outputs * padding_mask + + def remove_weight_norm(self): + if self.speaker_embedding_size != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for layer in self.in_layers: + torch.nn.utils.remove_weight_norm(layer) + for layer in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(layer) + + +class VitsPosteriorEncoder(nn.Module): + def __init__(self, config: VitsConfig): + super().__init__() + self.out_channels = config.flow_size + + self.conv_pre = nn.Conv1d(config.spectrogram_bins, config.hidden_size, 1) + self.wavenet = VitsWaveNet(config, num_layers=config.posterior_encoder_num_wavenet_layers) + self.conv_proj = nn.Conv1d(config.hidden_size, self.out_channels * 2, 1) + + def forward(self, inputs, padding_mask, global_conditioning=None): + inputs = 
self.conv_pre(inputs) * padding_mask + inputs = self.wavenet(inputs, padding_mask, global_conditioning) + stats = self.conv_proj(inputs) * padding_mask + mean, log_stddev = torch.split(stats, self.out_channels, dim=1) + sampled = (mean + torch.randn_like(mean) * torch.exp(log_stddev)) * padding_mask + return sampled, mean, log_stddev + + +# Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock +class HifiGanResidualBlock(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1): + super().__init__() + self.leaky_relu_slope = leaky_relu_slope + + self.convs1 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=dilation[i], + padding=self.get_padding(kernel_size, dilation[i]), + ) + for i in range(len(dilation)) + ] + ) + self.convs2 = nn.ModuleList( + [ + nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + for _ in range(len(dilation)) + ] + ) + + def get_padding(self, kernel_size, dilation=1): + return (kernel_size * dilation - dilation) // 2 + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + for layer in self.convs1: + weight_norm(layer) + for layer in self.convs2: + weight_norm(layer) + + def remove_weight_norm(self): + for layer in self.convs1: + nn.utils.remove_weight_norm(layer) + for layer in self.convs2: + nn.utils.remove_weight_norm(layer) + + def forward(self, hidden_states): + for conv1, conv2 in zip(self.convs1, self.convs2): + residual = hidden_states + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv1(hidden_states) + hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope) + hidden_states = conv2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class VitsHifiGan(nn.Module): + def __init__(self, config: VitsConfig): + super().__init__() + self.config = config + self.num_kernels = len(config.resblock_kernel_sizes) + self.num_upsamples = len(config.upsample_rates) + self.conv_pre = nn.Conv1d( + config.flow_size, + config.upsample_initial_channel, + kernel_size=7, + stride=1, + padding=3, + ) + + self.upsampler = nn.ModuleList() + for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)): + self.upsampler.append( + nn.ConvTranspose1d( + config.upsample_initial_channel // (2**i), + config.upsample_initial_channel // (2 ** (i + 1)), + kernel_size=kernel_size, + stride=upsample_rate, + padding=(kernel_size - upsample_rate) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.upsampler)): + channels = config.upsample_initial_channel // (2 ** (i + 1)) + for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes): + self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope)) + + self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False) + + if config.speaker_embedding_size != 0: + self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1) + + def apply_weight_norm(self): + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + for layer in self.upsampler: + 
weight_norm(layer)
+        for layer in self.resblocks:
+            layer.apply_weight_norm()
+
+    def remove_weight_norm(self):
+        for layer in self.upsampler:
+            nn.utils.remove_weight_norm(layer)
+        for layer in self.resblocks:
+            layer.remove_weight_norm()
+
+    def forward(
+        self, spectrogram: torch.FloatTensor, global_conditioning: Optional[torch.FloatTensor] = None
+    ) -> torch.FloatTensor:
+        r"""
+        Converts a spectrogram into a speech waveform.
+
+        Args:
+            spectrogram (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`):
+                Tensor containing the spectrograms.
+            global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*):
+                Tensor containing speaker embeddings, for multispeaker models.
+
+        Returns:
+            `torch.FloatTensor`: Tensor of shape `(batch_size, 1, num_frames)` containing the speech waveform.
+        """
+        hidden_states = self.conv_pre(spectrogram)
+
+        if global_conditioning is not None:
+            hidden_states = hidden_states + self.cond(global_conditioning)
+
+        for i in range(self.num_upsamples):
+            hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
+            hidden_states = self.upsampler[i](hidden_states)
+
+            res_state = self.resblocks[i * self.num_kernels](hidden_states)
+            for j in range(1, self.num_kernels):
+                res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
+            hidden_states = res_state / self.num_kernels
+
+        hidden_states = nn.functional.leaky_relu(hidden_states)
+        hidden_states = self.conv_post(hidden_states)
+        waveform = torch.tanh(hidden_states)
+        return waveform
+
+
+class VitsResidualCouplingLayer(nn.Module):
+    def __init__(self, config: VitsConfig):
+        super().__init__()
+        self.half_channels = config.flow_size // 2
+
+        self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
+        self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
+        self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
+
+    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
+        first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
+        hidden_states = self.conv_pre(first_half) * padding_mask
+        hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
+        mean = self.conv_post(hidden_states) * padding_mask
+        log_stddev = torch.zeros_like(mean)
+
+        if not reverse:
+            second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
+            outputs = torch.cat([first_half, second_half], dim=1)
+            log_determinant = torch.sum(log_stddev, [1, 2])
+            return outputs, log_determinant
+        else:
+            second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
+            outputs = torch.cat([first_half, second_half], dim=1)
+            return outputs, None
+
+
+class VitsResidualCouplingBlock(nn.Module):
+    def __init__(self, config: VitsConfig):
+        super().__init__()
+        self.flows = nn.ModuleList()
+        for _ in range(config.prior_encoder_num_flows):
+            self.flows.append(VitsResidualCouplingLayer(config))
+
+    def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
+        if not reverse:
+            for flow in self.flows:
+                inputs, _ = flow(inputs, padding_mask, global_conditioning)
+                inputs = torch.flip(inputs, [1])
+        else:
+            for flow in reversed(self.flows):
+                inputs = torch.flip(inputs, [1])
+                inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
+        return inputs
+
+
+class VitsDilatedDepthSeparableConv(nn.Module):
+    def __init__(self, config: VitsConfig, 
dropout_rate=0.0): + super().__init__() + kernel_size = config.duration_predictor_kernel_size + channels = config.hidden_size + self.num_layers = config.depth_separable_num_layers + + self.dropout = nn.Dropout(dropout_rate) + self.convs_dilated = nn.ModuleList() + self.convs_pointwise = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(self.num_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_dilated.append( + nn.Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + self.convs_pointwise.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(nn.LayerNorm(channels)) + self.norms_2.append(nn.LayerNorm(channels)) + + def forward(self, inputs, padding_mask, global_conditioning=None): + if global_conditioning is not None: + inputs = inputs + global_conditioning + + for i in range(self.num_layers): + hidden_states = self.convs_dilated[i](inputs * padding_mask) + hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1) + hidden_states = nn.functional.gelu(hidden_states) + hidden_states = self.convs_pointwise[i](hidden_states) + hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1) + hidden_states = nn.functional.gelu(hidden_states) + hidden_states = self.dropout(hidden_states) + inputs = inputs + hidden_states + + return inputs * padding_mask + + +class VitsConvFlow(nn.Module): + def __init__(self, config: VitsConfig): + super().__init__() + self.filter_channels = config.hidden_size + self.half_channels = config.depth_separable_channels // 2 + self.num_bins = config.duration_predictor_flow_bins + self.tail_bound = config.duration_predictor_tail_bound + + self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1) + self.conv_dds = VitsDilatedDepthSeparableConv(config) + self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1) + + def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False): + first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1) + + hidden_states = self.conv_pre(first_half) + hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning) + hidden_states = self.conv_proj(hidden_states) * padding_mask + + batch_size, channels, length = first_half.shape + hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2) + + unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :] + + second_half, log_abs_det = _unconstrained_rational_quadratic_spline( + second_half, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + reverse=reverse, + tail_bound=self.tail_bound, + ) + + outputs = torch.cat([first_half, second_half], dim=1) * padding_mask + if not reverse: + log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2]) + return outputs, log_determinant + else: + return outputs, None + + +class VitsElementwiseAffine(nn.Module): + def __init__(self, config: VitsConfig): + super().__init__() + self.channels = config.depth_separable_channels + self.translate = nn.Parameter(torch.zeros(self.channels, 1)) + self.log_scale = 
nn.Parameter(torch.zeros(self.channels, 1)) + + def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False): + if not reverse: + outputs = self.translate + torch.exp(self.log_scale) * inputs + outputs = outputs * padding_mask + log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2]) + return outputs, log_determinant + else: + outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask + return outputs, None + + +class VitsStochasticDurationPredictor(nn.Module): + def __init__(self, config): + super().__init__() + embed_dim = config.speaker_embedding_size + filter_channels = config.hidden_size + + self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1) + self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.conv_dds = VitsDilatedDepthSeparableConv( + config, + dropout_rate=config.duration_predictor_dropout, + ) + + if embed_dim != 0: + self.cond = nn.Conv1d(embed_dim, filter_channels, 1) + + self.flows = nn.ModuleList() + self.flows.append(VitsElementwiseAffine(config)) + for _ in range(config.duration_predictor_num_flows): + self.flows.append(VitsConvFlow(config)) + + self.post_conv_pre = nn.Conv1d(1, filter_channels, 1) + self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_conv_dds = VitsDilatedDepthSeparableConv( + config, + dropout_rate=config.duration_predictor_dropout, + ) + + self.post_flows = nn.ModuleList() + self.post_flows.append(VitsElementwiseAffine(config)) + for _ in range(config.duration_predictor_num_flows): + self.post_flows.append(VitsConvFlow(config)) + + def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0): + inputs = torch.detach(inputs) + inputs = self.conv_pre(inputs) + + if global_conditioning is not None: + global_conditioning = torch.detach(global_conditioning) + inputs = inputs + self.cond(global_conditioning) + + inputs = self.conv_dds(inputs, padding_mask) + inputs = self.conv_proj(inputs) * padding_mask + + if not reverse: + hidden_states = self.post_conv_pre(durations) + hidden_states = self.post_conv_dds(hidden_states, padding_mask) + hidden_states = self.post_conv_proj(hidden_states) * padding_mask + + random_posterior = ( + torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype) + * padding_mask + ) + log_determinant_posterior_sum = 0 + latents_posterior = random_posterior + for flow in self.post_flows: + latents_posterior, log_determinant = flow( + latents_posterior, padding_mask, global_conditioning=inputs + hidden_states + ) + latents_posterior = torch.flip(latents_posterior, [1]) + log_determinant_posterior_sum += log_determinant + + first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1) + + log_determinant_posterior_sum += torch.sum( + (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2] + ) + logq = ( + torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2]) + - log_determinant_posterior_sum + ) + + first_half = (durations - torch.sigmoid(first_half)) * padding_mask + first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask + log_determinant_sum = torch.sum(-first_half, [1, 2]) + + latents = torch.cat([first_half, second_half], dim=1) + for flow in self.flows: + latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs) + latents = torch.flip(latents, [1]) + log_determinant_sum += log_determinant + + nll = 
torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum + return nll + logq + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + + latents = ( + torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype) + * noise_scale + ) + for flow in flows: + latents = torch.flip(latents, [1]) + latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True) + + log_duration, _ = torch.split(latents, [1, 1], dim=1) + return log_duration + + +class VitsDurationPredictor(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = config.duration_predictor_kernel_size + filter_channels = config.duration_predictor_filter_channels + + self.dropout = nn.Dropout(config.duration_predictor_dropout) + self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) + self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if config.speaker_embedding_size != 0: + self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1) + + def forward(self, inputs, padding_mask, global_conditioning=None): + inputs = torch.detach(inputs) + + if global_conditioning is not None: + global_conditioning = torch.detach(global_conditioning) + inputs = inputs + self.cond(global_conditioning) + + inputs = self.conv_1(inputs * padding_mask) + inputs = torch.relu(inputs) + inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1) + inputs = self.dropout(inputs) + + inputs = self.conv_2(inputs * padding_mask) + inputs = torch.relu(inputs) + inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1) + inputs = self.dropout(inputs) + + inputs = self.proj(inputs * padding_mask) + return inputs * padding_mask + + +class VitsAttention(nn.Module): + """Multi-headed attention with relative positional representation.""" + + def __init__(self, config: VitsConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.dropout = config.attention_dropout + self.window_size = config.window_size + + self.head_dim = self.embed_dim // self.num_heads + self.scaling = self.head_dim**-0.5 + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}" + f" and `num_attention_heads`: {self.num_heads})." 
+ ) + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias) + + if self.window_size: + self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling) + self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.view(*proj_shape) + value_states = value_states.view(*proj_shape) + + src_len = key_states.size(1) + attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) + + if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if self.window_size is not None: + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len) + relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1)) + rel_pos_bias = self._relative_position_to_absolute_position(relative_logits) + attn_weights += rel_pos_bias + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. 
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        if self.window_size is not None:
+            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len)
+            relative_weights = self._absolute_position_to_relative_position(attn_probs)
+            rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings)
+            attn_output += rel_pos_bias
+
+        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+
+        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+        # partitioned across GPUs when using tensor-parallelism.
+        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        pad_length = max(length - (self.window_size + 1), 0)
+        if pad_length > 0:
+            relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
+
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        return relative_embeddings[:, slice_start_position:slice_end_position]
+
+    def _relative_position_to_absolute_position(self, x):
+        batch_heads, length, _ = x.size()
+
+        # Concat columns of pad to shift from relative to absolute indexing.
+        x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0])
+
+        # Concat extra elements so as to add up to shape (len+1, 2*len-1).
+        x_flat = x.view([batch_heads, length * 2 * length])
+        x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0])
+
+        # Reshape and slice out the padded elements. 
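+        # the result has shape (batch_heads, length, length): one relative score per (query, key) pair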
+ x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1]) + x_final = x_final[:, :length, length - 1 :] + return x_final + + def _absolute_position_to_relative_position(self, x): + batch_heads, length, _ = x.size() + + # Pad along column + x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0]) + x_flat = x.view([batch_heads, length * (2 * length - 1)]) + + # Add 0's in the beginning that will skew the elements after reshape + x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0]) + x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:] + return x_final + + +class VitsFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size) + self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size) + self.dropout = nn.Dropout(config.activation_dropout) + + if isinstance(config.hidden_act, str): + self.act_fn = ACT2FN[config.hidden_act] + else: + self.act_fn = config.hidden_act + + if config.ffn_kernel_size > 1: + pad_left = (config.ffn_kernel_size - 1) // 2 + pad_right = config.ffn_kernel_size // 2 + self.padding = [pad_left, pad_right, 0, 0, 0, 0] + else: + self.padding = None + + def forward(self, hidden_states, padding_mask): + hidden_states = hidden_states.permute(0, 2, 1) + padding_mask = padding_mask.permute(0, 2, 1) + + hidden_states = hidden_states * padding_mask + if self.padding is not None: + hidden_states = nn.functional.pad(hidden_states, self.padding) + + hidden_states = self.conv_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states * padding_mask + if self.padding is not None: + hidden_states = nn.functional.pad(hidden_states, self.padding) + + hidden_states = self.conv_2(hidden_states) + hidden_states = hidden_states * padding_mask + + hidden_states = hidden_states.permute(0, 2, 1) + return hidden_states + + +class VitsEncoderLayer(nn.Module): + def __init__(self, config: VitsConfig): + super().__init__() + self.attention = VitsAttention(config) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = VitsFeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + padding_mask: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + residual = hidden_states + hidden_states, attn_weights = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.layer_norm(residual + hidden_states) + + residual = hidden_states + hidden_states = self.feed_forward(hidden_states, padding_mask) + hidden_states = self.dropout(hidden_states) + hidden_states = self.final_layer_norm(residual + hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class VitsEncoder(nn.Module): + def __init__(self, config: VitsConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + self.layerdrop = config.layerdrop + + def forward( + self, + hidden_states: torch.FloatTensor, + padding_mask: torch.FloatTensor, + 
attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + hidden_states = hidden_states * padding_mask + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + + skip_the_layer = self.training and (dropout_probability < self.layerdrop) + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + padding_mask, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask=attention_mask, + padding_mask=padding_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = hidden_states * padding_mask + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class VitsTextEncoder(nn.Module): + """ + Transformer encoder that uses relative positional representation instead of absolute positional encoding. 
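+    A final 1x1 convolution projects the hidden states onto the mean and log-variance of the prior distribution
+    over the flow latents (`prior_means` and `prior_log_variances`).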
+ """ + + def __init__(self, config: VitsConfig): + super().__init__() + self.config = config + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.encoder = VitsEncoder(config) + self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.Tensor, + padding_mask: torch.FloatTensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], VitsTextEncoderOutput]: + hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size) + + encoder_outputs = self.encoder( + hidden_states=hidden_states, + padding_mask=padding_mask, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state + + stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask + prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2) + + if not return_dict: + outputs = (last_hidden_state, prior_means, prior_log_variances) + encoder_outputs[1:] + return outputs + + return VitsTextEncoderOutput( + last_hidden_state=last_hidden_state, + prior_means=prior_means, + prior_log_variances=prior_log_variances, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class VitsPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = VitsConfig + base_model_prefix = "vits" + main_input_name = "input_ids" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +VITS_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`VitsConfig`]): + Model configuration class with all the parameters of the model. 
Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +VITS_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + speaker_id (`int`, *optional*): + Which speaker embedding to use. Only used for multispeaker models. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The complete VITS model, for text-to-speech synthesis.", + VITS_START_DOCSTRING, +) +class VitsModel(VitsPreTrainedModel): + def __init__(self, config: VitsConfig): + super().__init__(config) + self.config = config + self.text_encoder = VitsTextEncoder(config) + self.flow = VitsResidualCouplingBlock(config) + self.decoder = VitsHifiGan(config) + + if config.use_stochastic_duration_prediction: + self.duration_predictor = VitsStochasticDurationPredictor(config) + else: + self.duration_predictor = VitsDurationPredictor(config) + + if config.num_speakers > 1: + self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size) + + # This is used only for training. + self.posterior_encoder = VitsPosteriorEncoder(config) + + # These parameters control the synthesised speech properties + self.speaking_rate = config.speaking_rate + self.noise_scale = config.noise_scale + self.noise_scale_duration = config.noise_scale_duration + + # Initialize weights and apply final processing + self.post_init() + + def get_encoder(self): + return self.text_encoder + + @add_start_docstrings_to_model_forward(VITS_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=VitsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + speaker_id: Optional[int] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple[Any], VitsModelOutput]: + r""" + labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*): + Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss + computation. 
+ + Returns: + + Example: + + ```python + >>> from transformers import VitsTokenizer, VitsModel, set_seed + >>> import torch + + >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng") + >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng") + + >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt") + + >>> set_seed(555) # make deterministic + + >>> with torch.no_grad(): + ... outputs = model(inputs["input_ids"]) + >>> outputs.waveform.shape + torch.Size([1, 45824]) + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + raise NotImplementedError("Training of VITS is not supported yet.") + + mask_dtype = self.text_encoder.embed_tokens.weight.dtype + if attention_mask is not None: + input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype) + else: + input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype) + + if self.config.num_speakers > 1 and speaker_id is not None: + if not 0 <= speaker_id < self.config.num_speakers: + raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.") + if isinstance(speaker_id, int): + speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device) + speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1) + else: + speaker_embeddings = None + + text_encoder_output = self.text_encoder( + input_ids=input_ids, + padding_mask=input_padding_mask, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state + hidden_states = hidden_states.transpose(1, 2) + input_padding_mask = input_padding_mask.transpose(1, 2) + prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means + prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances + + if self.config.use_stochastic_duration_prediction: + log_duration = self.duration_predictor( + hidden_states, + input_padding_mask, + speaker_embeddings, + reverse=True, + noise_scale=self.noise_scale_duration, + ) + else: + log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings) + + length_scale = 1.0 / self.speaking_rate + duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale) + predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long() + + # Create a padding mask for the output lengths of shape (batch, 1, max_output_length) + indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device) + output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1) + output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype) + + # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length) + attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1) + batch_size, _, output_length, input_length = attn_mask.shape + cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1) + indices = torch.arange(output_length, 
dtype=duration.dtype, device=duration.device) + valid_indices = indices.unsqueeze(0) < cum_duration + valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length) + padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1] + attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask + + # Expand prior distribution + prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2) + prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2) + + prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale + latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True) + + spectrogram = latents * output_padding_mask + waveform = self.decoder(spectrogram, speaker_embeddings) + waveform = waveform.squeeze(1) + sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates) + + if not return_dict: + outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:] + return outputs + + return VitsModelOutput( + waveform=waveform, + sequence_lengths=sequence_lengths, + spectrogram=spectrogram, + hidden_states=text_encoder_output.hidden_states, + attentions=text_encoder_output.attentions, + ) + + +__all__ = ["VitsModel", "VitsPreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/vits/tokenization_vits.py b/docs/transformers/build/lib/transformers/models/vits/tokenization_vits.py new file mode 100644 index 0000000000000000000000000000000000000000..ca40c80c124cae78e04bc65fd67982dcbd6020d6 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vits/tokenization_vits.py @@ -0,0 +1,246 @@ +# coding=utf-8 +# Copyright 2023 The Kakao Enterprise Authors, the MMS-TTS Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization class for VITS.""" + +import json +import os +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +from ...tokenization_utils import PreTrainedTokenizer +from ...utils import is_phonemizer_available, is_uroman_available, logging + + +if is_phonemizer_available(): + import phonemizer + +if is_uroman_available(): + import uroman as ur + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"} + + +def has_non_roman_characters(input_string): + # Find any character outside the ASCII range + non_roman_pattern = re.compile(r"[^\x00-\x7F]") + + # Search the input string for non-Roman characters + match = non_roman_pattern.search(input_string) + has_non_roman = match is not None + return has_non_roman + + +class VitsTokenizer(PreTrainedTokenizer): + """ + Construct a VITS tokenizer. Also supports MMS-TTS. + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. 
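+
+    Example (a minimal usage sketch, using the same `facebook/mms-tts-eng` checkpoint as the modeling example):
+
+    ```python
+    >>> from transformers import VitsTokenizer
+
+    >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
+    >>> inputs = tokenizer(text="Hello, my dog is cute", return_tensors="pt")
+    ```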
+
+    Args:
+        vocab_file (`str`):
+            Path to the vocabulary file.
+        language (`str`, *optional*):
+            Language identifier.
+        add_blank (`bool`, *optional*, defaults to `True`):
+            Whether to insert token id 0 in between the other tokens.
+        normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the input text by removing all casing and punctuation.
+        phonemize (`bool`, *optional*, defaults to `True`):
+            Whether to convert the input text into phonemes.
+        is_uroman (`bool`, *optional*, defaults to `False`):
+            Whether the `uroman` Romanizer needs to be applied to the input text prior to tokenizing.
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        pad_token="<pad>",
+        unk_token="<unk>",
+        language=None,
+        add_blank=True,
+        normalize=True,
+        phonemize=True,
+        is_uroman=False,
+        **kwargs,
+    ) -> None:
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.language = language
+        self.add_blank = add_blank
+        self.normalize = normalize
+        self.phonemize = phonemize
+
+        self.is_uroman = is_uroman
+
+        super().__init__(
+            pad_token=pad_token,
+            unk_token=unk_token,
+            language=language,
+            add_blank=add_blank,
+            normalize=normalize,
+            phonemize=phonemize,
+            is_uroman=is_uroman,
+            **kwargs,
+        )
+
+    @property
+    def vocab_size(self):
+        return len(self.encoder)
+
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def normalize_text(self, input_string):
+        """Lowercase the input string, respecting any special tokens that may be partly or entirely upper-cased."""
+        all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
+        filtered_text = ""
+
+        i = 0
+        while i < len(input_string):
+            found_match = False
+            for word in all_vocabulary:
+                if input_string[i : i + len(word)] == word:
+                    filtered_text += word
+                    i += len(word)
+                    found_match = True
+                    break
+
+            if not found_match:
+                filtered_text += input_string[i].lower()
+                i += 1
+
+        return filtered_text
+
+    def _preprocess_char(self, text):
+        """Special treatment of characters in certain languages."""
+        if self.language == "ron":
+            text = text.replace("ț", "ţ")
+        return text
+
+    def prepare_for_tokenization(
+        self, text: str, is_split_into_words: bool = False, normalize: Optional[bool] = None, **kwargs
+    ) -> Tuple[str, Dict[str, Any]]:
+        """
+        Performs any necessary transformations before tokenization.
+
+        This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
+        `kwargs` at the end of the encoding process to be sure all the arguments have been used.
+
+        Args:
+            text (`str`):
+                The text to prepare.
+            is_split_into_words (`bool`, *optional*, defaults to `False`):
+                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
+                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
+                which it will tokenize.
+            normalize (`bool`, *optional*, defaults to `None`):
+                Whether or not to apply punctuation and casing normalization to the text inputs. Typically, VITS is
+                trained on lower-cased and un-punctuated text. Hence, normalization is used to ensure that the input
+                text consists only of lower-case characters.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Keyword arguments to use for the tokenization. 
+
+        Returns:
+            `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.
+        """
+        normalize = normalize if normalize is not None else self.normalize
+
+        if normalize:
+            # normalise for casing
+            text = self.normalize_text(text)
+
+        filtered_text = self._preprocess_char(text)
+
+        if has_non_roman_characters(filtered_text) and self.is_uroman:
+            if not is_uroman_available():
+                logger.warning(
+                    "Text to the tokenizer contains non-Roman characters. To apply the `uroman` pre-processing "
+                    "step automatically, ensure the `uroman` Romanizer is installed with: `pip install uroman`. "
+                    "Note that `uroman` requires python version >= 3.10. Otherwise, apply the Romanizer manually "
+                    "as per the instructions: https://github.com/isi-nlp/uroman"
+                )
+            else:
+                uroman = ur.Uroman()
+                filtered_text = uroman.romanize_string(filtered_text)
+
+        if self.phonemize:
+            if not is_phonemizer_available():
+                raise ImportError("Please install the `phonemizer` Python package to use this tokenizer.")
+
+            filtered_text = phonemizer.phonemize(
+                filtered_text,
+                language="en-us",
+                backend="espeak",
+                strip=True,
+                preserve_punctuation=True,
+                with_stress=True,
+            )
+            filtered_text = re.sub(r"\s+", " ", filtered_text)
+        elif normalize:
+            # strip any chars outside of the vocab (punctuation)
+            filtered_text = "".join(list(filter(lambda char: char in self.encoder, filtered_text))).strip()
+
+        return filtered_text, kwargs
+
+    def _tokenize(self, text: str) -> List[str]:
+        """Tokenize a string by inserting the `<pad>` token at the boundary between adjacent characters."""
+        tokens = list(text)
+
+        if self.add_blank:
+            interspersed = [self._convert_id_to_token(0)] * (len(tokens) * 2 + 1)
+            interspersed[1::2] = tokens
+            tokens = interspersed
+
+        return tokens
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        if self.add_blank and len(tokens) > 1:
+            tokens = tokens[1::2]
+        return "".join(tokens)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) into an id using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) into a token (str) using the vocab."""
+        return self.decoder.get(index)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Union[Tuple[str], None]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        return (vocab_file,)
+
+
+__all__ = ["VitsTokenizer"]
diff --git a/docs/transformers/build/lib/transformers/models/vivit/__init__.py b/docs/transformers/build/lib/transformers/models/vivit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13c52cbcff0b3fc6dc3255f41729754991dd3c8f
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/vivit/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_vivit import * + from .image_processing_vivit import * + from .modeling_vivit import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/docs/transformers/build/lib/transformers/models/vivit/configuration_vivit.py b/docs/transformers/build/lib/transformers/models/vivit/configuration_vivit.py new file mode 100644 index 0000000000000000000000000000000000000000..42863454e81c278ee7af62fee06b19ec9b6ca2f0 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vivit/configuration_vivit.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ViViT model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class VivitConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`VivitModel`]. It is used to instantiate a ViViT + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the ViViT + [google/vivit-b-16x2-kinetics400](https://huggingface.co/google/vivit-b-16x2-kinetics400) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + num_frames (`int`, *optional*, defaults to 32): + The number of frames in each video. + tubelet_size (`List[int]`, *optional*, defaults to `[2, 16, 16]`): + The size (resolution) of each tubelet. + num_channels (`int`, *optional*, defaults to 3): + The number of input channels. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. 
+ intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_fast"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"`, `"gelu_fast"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + qkv_bias (`bool`, *optional*, defaults to `True`): + Whether to add a bias to the queries, keys and values. + + Example: + + ```python + >>> from transformers import VivitConfig, VivitModel + + >>> # Initializing a ViViT google/vivit-b-16x2-kinetics400 style configuration + >>> configuration = VivitConfig() + + >>> # Initializing a model (with random weights) from the google/vivit-b-16x2-kinetics400 style configuration + >>> model = VivitModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "vivit" + + def __init__( + self, + image_size=224, + num_frames=32, + tubelet_size=[2, 16, 16], + num_channels=3, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu_fast", + hidden_dropout_prob=0.0, + attention_probs_dropout_prob=0.0, + initializer_range=0.02, + layer_norm_eps=1e-06, + qkv_bias=True, + **kwargs, + ): + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.image_size = image_size + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.num_channels = num_channels + self.qkv_bias = qkv_bias + + super().__init__(**kwargs) + + +__all__ = ["VivitConfig"] diff --git a/docs/transformers/build/lib/transformers/models/vivit/convert_vivit_flax_to_pytorch.py b/docs/transformers/build/lib/transformers/models/vivit/convert_vivit_flax_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b5e1cfda311b51f62b85a059e420841feda708 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vivit/convert_vivit_flax_to_pytorch.py @@ -0,0 +1,231 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Flax ViViT checkpoints from the original repository to PyTorch. URL: +https://github.com/google-research/scenic/tree/main/scenic/projects/vivit +""" + +import argparse +import json +import os.path +from collections import OrderedDict + +import numpy as np +import requests +import torch +from flax.training.checkpoints import restore_checkpoint +from huggingface_hub import hf_hub_download + +from transformers import VivitConfig, VivitForVideoClassification, VivitImageProcessor +from transformers.image_utils import PILImageResampling + + +def download_checkpoint(path): + url = "https://storage.googleapis.com/scenic-bucket/vivit/kinetics_400/vivit_base_16x2_unfactorized/checkpoint" + + with open(path, "wb") as f: + with requests.get(url, stream=True) as req: + for chunk in req.iter_content(chunk_size=2048): + f.write(chunk) + + +def get_vivit_config() -> VivitConfig: + config = VivitConfig() + + config.num_labels = 400 + repo_id = "huggingface/label-files" + filename = "kinetics400-id2label.json" + + id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r")) + id2label = {int(k): v for k, v in id2label.items()} + config.id2label = id2label + config.label2id = {v: k for k, v in id2label.items()} + return config + + +# We will verify our results on a video of eating spaghetti +# Frame indices used: [ 47, 51, 55, 59, 63, 67, 71, 75, 80, 84, 88, 92, 96, 100, 104, 108, 113, 117, +# 121, 125, 129, 133, 137, 141, 146, 150, 154, 158, 162, 166, 170, 174] +def prepare_video(): + file = hf_hub_download( + repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti_32_frames.npy", repo_type="dataset" + ) + video = np.load(file) + return list(video) + + +def transform_attention(current: np.ndarray): + if np.ndim(current) == 2: + return transform_attention_bias(current) + + elif np.ndim(current) == 3: + return transform_attention_kernel(current) + + else: + raise Exception(f"Invalid number of dimensions: {np.ndim(current)}") + + +def transform_attention_bias(current: np.ndarray): + return current.flatten() + + +def transform_attention_kernel(current: np.ndarray): + return np.reshape(current, (current.shape[0], current.shape[1] * current.shape[2])).T + + +def transform_attention_output_weight(current: np.ndarray): + return np.reshape(current, (current.shape[0] * current.shape[1], current.shape[2])).T + + +def transform_state_encoder_block(state_dict, i): + state = state_dict["optimizer"]["target"]["Transformer"][f"encoderblock_{i}"] + + prefix = f"encoder.layer.{i}." 
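+    # Map this block's Flax parameter tree onto flat PyTorch state-dict keys. Dense and attention
+    # kernels are transposed (or reshaped) because Flax stores weights as (in, out) while PyTorch
+    # expects (out, in).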
+ new_state = { + prefix + "intermediate.dense.bias": state["MlpBlock_0"]["Dense_0"]["bias"], + prefix + "intermediate.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_0"]["kernel"]), + prefix + "output.dense.bias": state["MlpBlock_0"]["Dense_1"]["bias"], + prefix + "output.dense.weight": np.transpose(state["MlpBlock_0"]["Dense_1"]["kernel"]), + prefix + "layernorm_before.bias": state["LayerNorm_0"]["bias"], + prefix + "layernorm_before.weight": state["LayerNorm_0"]["scale"], + prefix + "layernorm_after.bias": state["LayerNorm_1"]["bias"], + prefix + "layernorm_after.weight": state["LayerNorm_1"]["scale"], + prefix + "attention.attention.query.bias": transform_attention( + state["MultiHeadDotProductAttention_0"]["query"]["bias"] + ), + prefix + "attention.attention.query.weight": transform_attention( + state["MultiHeadDotProductAttention_0"]["query"]["kernel"] + ), + prefix + "attention.attention.key.bias": transform_attention( + state["MultiHeadDotProductAttention_0"]["key"]["bias"] + ), + prefix + "attention.attention.key.weight": transform_attention( + state["MultiHeadDotProductAttention_0"]["key"]["kernel"] + ), + prefix + "attention.attention.value.bias": transform_attention( + state["MultiHeadDotProductAttention_0"]["value"]["bias"] + ), + prefix + "attention.attention.value.weight": transform_attention( + state["MultiHeadDotProductAttention_0"]["value"]["kernel"] + ), + prefix + "attention.output.dense.bias": state["MultiHeadDotProductAttention_0"]["out"]["bias"], + prefix + "attention.output.dense.weight": transform_attention_output_weight( + state["MultiHeadDotProductAttention_0"]["out"]["kernel"] + ), + } + + return new_state + + +def get_n_layers(state_dict): + return sum([1 if "encoderblock_" in k else 0 for k in state_dict["optimizer"]["target"]["Transformer"].keys()]) + + +def transform_state(state_dict, classification_head=False): + transformer_layers = get_n_layers(state_dict) + + new_state = OrderedDict() + + new_state["layernorm.bias"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["bias"] + new_state["layernorm.weight"] = state_dict["optimizer"]["target"]["Transformer"]["encoder_norm"]["scale"] + + new_state["embeddings.patch_embeddings.projection.weight"] = np.transpose( + state_dict["optimizer"]["target"]["embedding"]["kernel"], (4, 3, 0, 1, 2) + ) + new_state["embeddings.patch_embeddings.projection.bias"] = state_dict["optimizer"]["target"]["embedding"]["bias"] + + new_state["embeddings.cls_token"] = state_dict["optimizer"]["target"]["cls"] + new_state["embeddings.position_embeddings"] = state_dict["optimizer"]["target"]["Transformer"]["posembed_input"][ + "pos_embedding" + ] + + for i in range(transformer_layers): + new_state.update(transform_state_encoder_block(state_dict, i)) + + if classification_head: + new_state = {"vivit." 
+ k: v for k, v in new_state.items()}
+        new_state["classifier.weight"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["kernel"])
+        new_state["classifier.bias"] = np.transpose(state_dict["optimizer"]["target"]["output_projection"]["bias"])
+
+    return {k: torch.tensor(v) for k, v in new_state.items()}
+
+
+# checks that image processor settings are the same as in the original implementation
+# original: https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/data/video_tfrecord_dataset.py
+# dataset specific config:
+# https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/configs/kinetics400/vivit_base_k400.py
+def get_processor() -> VivitImageProcessor:
+    extractor = VivitImageProcessor()
+
+    assert extractor.do_resize is True
+    assert extractor.size == {"shortest_edge": 256}
+    assert extractor.do_center_crop is True
+    assert extractor.crop_size == {"width": 224, "height": 224}
+    assert extractor.resample == PILImageResampling.BILINEAR
+
+    # here: https://github.com/deepmind/dmvr/blob/master/dmvr/modalities.py
+    # one can see that add_image has default values for normalization_mean and normalization_std set to 0 and 1
+    # which effectively means no normalization (and ViViT does not overwrite those when calling this func)
+    assert extractor.do_normalize is False
+    assert extractor.do_rescale is True
+    assert extractor.rescale_factor == 1 / 127.5
+
+    # zero-centering = True in original implementation; the processor expresses this via `offset`
+    assert extractor.offset is True
+
+    return extractor
+
+
+def convert(output_path: str):
+    flax_model_path = "checkpoint"
+
+    if not os.path.exists(flax_model_path):
+        download_checkpoint(flax_model_path)
+
+    state_dict = restore_checkpoint(flax_model_path, None)
+    new_state = transform_state(state_dict, classification_head=True)
+
+    config = get_vivit_config()
+
+    assert config.image_size == 224
+    assert config.num_frames == 32
+
+    model = VivitForVideoClassification(config)
+    model.load_state_dict(new_state)
+    model.eval()
+
+    extractor = get_processor()
+
+    video = prepare_video()
+    inputs = extractor(video, return_tensors="pt")
+
+    outputs = model(**inputs)
+
+    expected_shape = torch.Size([1, 400])
+    expected_slice = torch.tensor([-1.0543, 2.0764, -0.2104, 0.4439, -0.9658])
+
+    assert outputs.logits.shape == expected_shape
+    assert torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4), outputs.logits[0, :5]
+
+    model.save_pretrained(output_path)
+    extractor.save_pretrained(output_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--output_model_name", "-o", type=str, help="Output path for the converted HuggingFace model")
+
+    args = parser.parse_args()
+    convert(args.output_model_name)
diff --git a/docs/transformers/build/lib/transformers/models/vivit/image_processing_vivit.py b/docs/transformers/build/lib/transformers/models/vivit/image_processing_vivit.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b369be41ba883b0d9a853a043a7a17417024fc4
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/vivit/image_processing_vivit.py
@@ -0,0 +1,407 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Image processor class for Vivit."""
+
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+
+from transformers.utils import is_vision_available
+from transformers.utils.generic import TensorType
+
+from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
+from ...image_transforms import (
+    get_resize_output_image_size,
+    rescale,
+    resize,
+    to_channel_dimension_format,
+)
+from ...image_utils import (
+    IMAGENET_STANDARD_MEAN,
+    IMAGENET_STANDARD_STD,
+    ChannelDimension,
+    ImageInput,
+    PILImageResampling,
+    infer_channel_dimension_format,
+    is_scaled_image,
+    is_valid_image,
+    to_numpy_array,
+    valid_images,
+    validate_preprocess_arguments,
+)
+from ...utils import filter_out_non_signature_kwargs, logging
+
+
+if is_vision_available():
+    import PIL
+
+logger = logging.get_logger(__name__)
+
+
+def make_batched(videos) -> List[List[ImageInput]]:
+    if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
+        return videos
+
+    elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]):
+        return [videos]
+
+    elif is_valid_image(videos):
+        return [[videos]]
+
+    raise ValueError(f"Could not make batched video from {videos}")
+
+
+class VivitImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs a Vivit image processor.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
+            `do_resize` parameter in the `preprocess` method.
+        size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 256}`):
+            Size of the output image after resizing. The shortest edge of the image will be resized to
+            `size["shortest_edge"]` while maintaining the aspect ratio of the original image. Can be overridden by
+            `size` in the `preprocess` method.
+        resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+            Resampling filter to use if resizing the image. Can be overridden by the `resample` parameter in the
+            `preprocess` method.
+        do_center_crop (`bool`, *optional*, defaults to `True`):
+            Whether to center crop the image to the specified `crop_size`. Can be overridden by the `do_center_crop`
+            parameter in the `preprocess` method.
+        crop_size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
+            Size of the image after applying the center crop. Can be overridden by the `crop_size` parameter in the
+            `preprocess` method.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the `do_rescale`
+            parameter in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `1/127.5`):
+            Defines the scale factor to use if rescaling the image. Can be overridden by the `rescale_factor` parameter
+            in the `preprocess` method.
+        offset (`bool`, *optional*, defaults to `True`):
+            Whether to scale the image in both negative and positive directions. Can be overridden by the `offset`
+            parameter in the `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
+            Mean to use if normalizing the image. This is a float or list of floats the length of the number of
+            channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
+            Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
+            number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_center_crop: bool = True,
+        crop_size: Dict[str, int] = None,
+        do_rescale: bool = True,
+        rescale_factor: Union[int, float] = 1 / 127.5,
+        offset: bool = True,
+        do_normalize: bool = True,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        size = size if size is not None else {"shortest_edge": 256}
+        size = get_size_dict(size, default_to_square=False)
+        crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
+        crop_size = get_size_dict(crop_size, param_name="crop_size")
+
+        self.do_resize = do_resize
+        self.size = size
+        self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.offset = offset
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`Dict[str, int]`):
+                Size of the output image. If `size` is of the form `{"height": h, "width": w}`, the output image will
+                have the size `(h, w)`. If `size` is of the form `{"shortest_edge": s}`, the output image will have its
+                shortest edge of length `s` while keeping the aspect ratio of the original image.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
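+
+        Example (an illustrative sketch added for clarity, not part of the original file; assumes an RGB
+        channels-first array and that Pillow is available):
+
+        ```python
+        >>> import numpy as np
+        >>> from transformers import VivitImageProcessor
+
+        >>> processor = VivitImageProcessor()
+        >>> frame = np.zeros((3, 270, 480), dtype=np.uint8)
+        >>> resized = processor.resize(frame, size={"shortest_edge": 256})
+        >>> resized.shape[1]  # the shorter edge (height) is now 256; width keeps the aspect ratio
+        256
+        ```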
+ """ + size = get_size_dict(size, default_to_square=False) + if "shortest_edge" in size: + output_size = get_resize_output_image_size( + image, size["shortest_edge"], default_to_square=False, input_data_format=input_data_format + ) + elif "height" in size and "width" in size: + output_size = (size["height"], size["width"]) + else: + raise ValueError(f"Size must have 'height' and 'width' or 'shortest_edge' as keys. Got {size.keys()}") + return resize( + image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + **kwargs, + ) + + # Copied from transformers.models.efficientnet.image_processing_efficientnet.EfficientNetImageProcessor.rescale + def rescale( + self, + image: np.ndarray, + scale: Union[int, float], + offset: bool = True, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + **kwargs, + ): + """ + Rescale an image by a scale factor. + + If `offset` is `True`, the image has its values rescaled by `scale` and then offset by 1. If `scale` is + 1/127.5, the image is rescaled between [-1, 1]. + image = image * scale - 1 + + If `offset` is `False`, and `scale` is 1/255, the image is rescaled between [0, 1]. + image = image * scale + + Args: + image (`np.ndarray`): + Image to rescale. + scale (`int` or `float`): + Scale to apply to the image. + offset (`bool`, *optional*): + Whether to scale the image in both negative and positive directions. + data_format (`str` or `ChannelDimension`, *optional*): + The channel dimension format of the image. If not provided, it will be the same as the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + """ + rescaled_image = rescale( + image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs + ) + + if offset: + rescaled_image = rescaled_image - 1 + + return rescaled_image + + def _preprocess_image( + self, + image: ImageInput, + do_resize: Optional[bool] = None, + size: Dict[str, int] = None, + resample: PILImageResampling = None, + do_center_crop: Optional[bool] = None, + crop_size: Dict[str, int] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + offset: Optional[bool] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single image.""" + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_resize=do_resize, + size=size, + resample=resample, + ) + + if offset and not do_rescale: + raise ValueError("For offset, do_rescale must also be set to True.") + + # All transformations expect numpy arrays. + image = to_numpy_array(image) + + if do_rescale and is_scaled_image(image): + logger.warning_once( + "It looks like you are trying to rescale already rescaled images. If the input" + " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 
+            )
+
+        if input_data_format is None:
+            input_data_format = infer_channel_dimension_format(image)
+
+        if do_resize:
+            image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+
+        if do_center_crop:
+            image = self.center_crop(image, size=crop_size, input_data_format=input_data_format)
+
+        if do_rescale:
+            image = self.rescale(image=image, scale=rescale_factor, offset=offset, input_data_format=input_data_format)
+
+        if do_normalize:
+            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
+
+        image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
+        return image
+
+    @filter_out_non_signature_kwargs()
+    def preprocess(
+        self,
+        videos: ImageInput,
+        do_resize: Optional[bool] = None,
+        size: Dict[str, int] = None,
+        resample: PILImageResampling = None,
+        do_center_crop: Optional[bool] = None,
+        crop_size: Dict[str, int] = None,
+        do_rescale: Optional[bool] = None,
+        rescale_factor: Optional[float] = None,
+        offset: Optional[bool] = None,
+        do_normalize: Optional[bool] = None,
+        image_mean: Optional[Union[float, List[float]]] = None,
+        image_std: Optional[Union[float, List[float]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        data_format: ChannelDimension = ChannelDimension.FIRST,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> BatchFeature:
+        """
+        Preprocess an image or batch of images.
+
+        Args:
+            videos (`ImageInput`):
+                Video frames to preprocess. Expects a single or batch of video frames with pixel values ranging from 0
+                to 255. If passing in frames with pixel values between 0 and 1, set `do_rescale=False`.
+            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+                Whether to resize the image.
+            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+                Size of the image after applying resize.
+            resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+                Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
+                has an effect if `do_resize` is set to `True`.
+            do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+                Whether to center crop the image.
+            crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
+                Size of the image after applying the center crop.
+            do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+                Whether to rescale the image values between `[-1, 1]` if `offset` is `True`, `[0, 1]` otherwise.
+            rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
+                Rescale factor to rescale the image by if `do_rescale` is set to `True`.
+            offset (`bool`, *optional*, defaults to `self.offset`):
+                Whether to scale the image in both negative and positive directions.
+            do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+                Whether to normalize the image.
+            image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
+                Image mean.
+            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
+                Image standard deviation.
+            return_tensors (`str` or `TensorType`, *optional*):
+                The type of tensors to return. Can be one of:
+                    - Unset: Return a list of `np.ndarray`.
+                    - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
+                    - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
+                    - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + The channel dimension format for the output image. Can be one of: + - `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - Unset: Use the inferred channel dimension format of the input image. + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format for the input image. If unset, the channel dimension format is inferred + from the input image. Can be one of: + - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. + - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. + - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. + """ + do_resize = do_resize if do_resize is not None else self.do_resize + resample = resample if resample is not None else self.resample + do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + offset = offset if offset is not None else self.offset + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + crop_size = crop_size if crop_size is not None else self.crop_size + crop_size = get_size_dict(crop_size, param_name="crop_size") + + if not valid_images(videos): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + videos = make_batched(videos) + + videos = [ + [ + self._preprocess_image( + image=img, + do_resize=do_resize, + size=size, + resample=resample, + do_center_crop=do_center_crop, + crop_size=crop_size, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + offset=offset, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + for img in video + ] + for video in videos + ] + + data = {"pixel_values": videos} + return BatchFeature(data=data, tensor_type=return_tensors) + + +__all__ = ["VivitImageProcessor"] diff --git a/docs/transformers/build/lib/transformers/models/vivit/modeling_vivit.py b/docs/transformers/build/lib/transformers/models/vivit/modeling_vivit.py new file mode 100644 index 0000000000000000000000000000000000000000..669106239a06a83cb1f4737f76eb7a5bed474227 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/vivit/modeling_vivit.py @@ -0,0 +1,844 @@ +# coding=utf-8 +# Copyright 2023 Google AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch ViViT model."""
+
+from typing import Callable, Optional, Set, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+    torch_int,
+)
+from .configuration_vivit import VivitConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "google/vivit-b-16x2-kinetics400"
+_CONFIG_FOR_DOC = "VivitConfig"
+
+
+class VivitTubeletEmbeddings(nn.Module):
+    """
+    Construct Vivit Tubelet embeddings.
+
+    This module turns a batch of videos of shape (batch_size, num_frames, num_channels, height, width) into a tensor of
+    shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.
+
+    The seq_len (the number of patches) equals (number of frames // tubelet_size[0]) * (height // tubelet_size[1]) *
+    (width // tubelet_size[2]).
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.num_frames = config.num_frames
+        self.image_size = config.image_size
+        self.patch_size = config.tubelet_size
+        self.num_patches = (
+            (self.image_size // self.patch_size[2])
+            * (self.image_size // self.patch_size[1])
+            * (self.num_frames // self.patch_size[0])
+        )
+        self.embed_dim = config.hidden_size
+
+        self.projection = nn.Conv3d(
+            config.num_channels, config.hidden_size, kernel_size=config.tubelet_size, stride=config.tubelet_size
+        )
+
+    def forward(self, pixel_values, interpolate_pos_encoding: bool = False):
+        batch_size, num_frames, num_channels, height, width = pixel_values.shape
+        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+            raise ValueError(
+                f"Image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
+            )
+
+        # permute to (batch_size, num_channels, num_frames, height, width)
+        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
+
+        x = self.projection(pixel_values)
+        # out_batch_size, out_num_channels, out_num_frames, out_height, out_width = x.shape
+        # flattens time and space dimensions, transposes to (out_batch_size, flat_tokens, out_num_channels)
+        x = x.flatten(2).transpose(1, 2)
+        return x
+
+
+class VivitEmbeddings(nn.Module):
+    """
+    Vivit Embeddings.
+
+    Creates embeddings from a video using VivitTubeletEmbeddings, adds CLS token and positional embeddings.
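+
+    For the default configuration (32 frames of 224x224 pixels and a [2, 16, 16] tubelet size), this gives
+    (32 / 2) * (224 / 16) * (224 / 16) = 3136 patch tokens, plus the CLS token: a sequence length of 3137.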
+ """ + + def __init__(self, config): + super().__init__() + + self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.patch_embeddings = VivitTubeletEmbeddings(config) + + self.position_embeddings = nn.Parameter( + torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size) + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.patch_size = config.tubelet_size[1:] + self.config = config + + # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution + images. This method is also adapted to support torch.jit tracing. + + Adapted from: + - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and + - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + + # always interpolate when tracing to ensure the exported model works for dynamic input shapes + if not torch.jit.is_tracing() and num_patches == num_positions and height == width: + return self.position_embeddings + + class_pos_embed = self.position_embeddings[:, :1] + patch_pos_embed = self.position_embeddings[:, 1:] + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size[0] + new_width = width // self.patch_size[1] + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + + return torch.cat((class_pos_embed, patch_pos_embed), dim=1) + + def forward(self, pixel_values, interpolate_pos_encoding: bool = False): + batch_size, num_frames, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + cls_tokens = self.cls_token.tile([batch_size, 1, 1]) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +# Copied from transformers.models.vit.modeling_vit.eager_attention_forward +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + # Take the dot product between "query" and "key" to get the raw attention scores. + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + + # Normalize the attention scores to probabilities. 
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + # Mask heads if we want to + if attention_mask is not None: + attn_weights = attn_weights * attention_mask + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfAttention with ViT->Vivit +class VivitSelfAttention(nn.Module): + def __init__(self, config: VivitConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.config = config + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dropout_prob = config.attention_probs_dropout_prob + self.scaling = self.attention_head_size**-0.5 + self.is_causal = False + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + context_layer, attention_probs = attention_interface( + self, + query_layer, + key_layer, + value_layer, + head_mask, + is_causal=self.is_causal, + scaling=self.scaling, + dropout=0.0 if not self.training else self.dropout_prob, + ) + + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit +class VivitSelfOutput(nn.Module): + """ + The residual connection is defined in VivitLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: VivitConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->Vivit +class VivitAttention(nn.Module): + def __init__(self, config: VivitConfig) -> None: + super().__init__() + self.attention = VivitSelfAttention(config) + self.output = VivitSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class VivitIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class VivitOutput(nn.Module): + def __init__(self, config): + super().__init__() + 
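+        # Projects the intermediate (MLP) activations back down to hidden_size; the residual
+        # connection with the block input is applied in forward() below.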
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +class VivitLayer(nn.Module): + """This corresponds to the EncoderBlock class in the scenic/vivit implementation.""" + + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = VivitAttention(config) + self.intermediate = VivitIntermediate(config) + self.output = VivitOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, head_mask=None, output_attentions=False): + self_attention_outputs = self.attention( + # in Vivit, layernorm is applied before self-attention + self.layernorm_before(hidden_states), + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] + + # first residual connection + hidden_states = attention_output + hidden_states + + # in Vivit, layernorm is also applied after self-attention + layer_output = self.layernorm_after(hidden_states) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class VivitEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([VivitLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class VivitPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
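+        # For ViViT the first token is the learned CLS token prepended in VivitEmbeddings,
+        # so the pooled output is a single hidden_size vector per video.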
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class VivitPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = VivitConfig
+    base_model_prefix = "vivit"
+    main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
+    _no_split_modules = []
+    _supports_sdpa = True
+    _supports_flash_attn_2 = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Conv3d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, VivitEmbeddings):
+            module.cls_token.data.zero_()
+            module.position_embeddings.data.zero_()
+
+
+VIVIT_START_DOCSTRING = r"""
+    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
+    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+    behavior.
+
+    Parameters:
+        config ([`VivitConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+VIVIT_INPUTS_DOCSTRING = r"""
+    Args:
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
+            Pixel values. Pixel values can be obtained using [`VivitImageProcessor`]. See
+            [`VivitImageProcessor.preprocess`] for details.
+
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+""" + + +@add_start_docstrings( + "The bare ViViT Transformer model outputting raw hidden-states without any specific head on top.", + VIVIT_START_DOCSTRING, +) +class VivitModel(VivitPreTrainedModel): + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = VivitEmbeddings(config) + self.encoder = VivitEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = VivitPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. + + Args: + heads_to_prune: + dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]: + r""" + Returns: + + Examples: + + ```python + >>> import av + >>> import numpy as np + + >>> from transformers import VivitImageProcessor, VivitModel + >>> from huggingface_hub import hf_hub_download + + >>> np.random.seed(0) + + + >>> def read_video_pyav(container, indices): + ... ''' + ... Decode the video with PyAV decoder. + ... Args: + ... container (`av.container.input.InputContainer`): PyAV container. + ... indices (`List[int]`): List of frame indices to decode. + ... Returns: + ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). + ... ''' + ... frames = [] + ... container.seek(0) + ... start_index = indices[0] + ... end_index = indices[-1] + ... for i, frame in enumerate(container.decode(video=0)): + ... if i > end_index: + ... break + ... if i >= start_index and i in indices: + ... frames.append(frame) + ... return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + + >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' + ... converted_len = int(clip_len * frame_sample_rate) + ... end_idx = np.random.randint(converted_len, seg_len) + ... start_idx = end_idx - converted_len + ... indices = np.linspace(start_idx, end_idx, num=clip_len) + ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64) + ... return indices + + + >>> # video clip consists of 300 frames (10 seconds at 30 FPS) + >>> file_path = hf_hub_download( + ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset" + ... 
)
+        >>> container = av.open(file_path)
+
+        >>> # sample 32 frames
+        >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container=container, indices=indices)
+
+        >>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
+        >>> model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400")
+
+        >>> # prepare video for the model
+        >>> inputs = image_processor(list(video), return_tensors="pt")
+
+        >>> # forward pass
+        >>> outputs = model(**inputs)
+        >>> last_hidden_states = outputs.last_hidden_state
+        >>> list(last_hidden_states.shape)
+        [1, 3137, 768]
+        ```"""
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if pixel_values is None:
+            raise ValueError("You have to specify pixel_values")
+
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+
+        encoder_outputs = self.encoder(
+            embedding_output,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    ViViT Transformer model with a video classification head on top (a linear layer on top of the final hidden state of the
+[CLS] token) e.g. for Kinetics-400.
+
+    <Tip>
+
+    Note that it's possible to fine-tune ViViT on higher resolution images than the ones it has been trained on, by
+    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+    position embeddings to the higher resolution.
+
+    </Tip>
+ + + """, + VIVIT_START_DOCSTRING, +) +class VivitForVideoClassification(VivitPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.num_labels = config.num_labels + self.vivit = VivitModel(config, add_pooling_layer=False) + + # Classifier head + self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: bool = False, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], ImageClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the image classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + + Returns: + + Examples: + + ```python + >>> import av + >>> import numpy as np + >>> import torch + + >>> from transformers import VivitImageProcessor, VivitForVideoClassification + >>> from huggingface_hub import hf_hub_download + + >>> np.random.seed(0) + + + >>> def read_video_pyav(container, indices): + ... ''' + ... Decode the video with PyAV decoder. + ... Args: + ... container (`av.container.input.InputContainer`): PyAV container. + ... indices (`List[int]`): List of frame indices to decode. + ... Returns: + ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3). + ... ''' + ... frames = [] + ... container.seek(0) + ... start_index = indices[0] + ... end_index = indices[-1] + ... for i, frame in enumerate(container.decode(video=0)): + ... if i > end_index: + ... break + ... if i >= start_index and i in indices: + ... frames.append(frame) + ... return np.stack([x.to_ndarray(format="rgb24") for x in frames]) + + + >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len): + ... ''' + ... Sample a given number of frame indices from the video. + ... Args: + ... clip_len (`int`): Total number of frames to sample. + ... frame_sample_rate (`int`): Sample every n-th frame. + ... seg_len (`int`): Maximum allowed index of sample's last frame. + ... Returns: + ... indices (`List[int]`): List of sampled frame indices + ... ''' + ... converted_len = int(clip_len * frame_sample_rate) + ... end_idx = np.random.randint(converted_len, seg_len) + ... start_idx = end_idx - converted_len + ... indices = np.linspace(start_idx, end_idx, num=clip_len) + ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64) + ... return indices + + + >>> # video clip consists of 300 frames (10 seconds at 30 FPS) + >>> file_path = hf_hub_download( + ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset" + ... 
)
+        >>> container = av.open(file_path)
+
+        >>> # sample 32 frames
+        >>> indices = sample_frame_indices(clip_len=32, frame_sample_rate=4, seg_len=container.streams.video[0].frames)
+        >>> video = read_video_pyav(container=container, indices=indices)
+
+        >>> image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
+        >>> model = VivitForVideoClassification.from_pretrained("google/vivit-b-16x2-kinetics400")
+
+        >>> inputs = image_processor(list(video), return_tensors="pt")
+
+        >>> with torch.no_grad():
+        ...     outputs = model(**inputs)
+        ...     logits = outputs.logits
+
+        >>> # model predicts one of the 400 Kinetics-400 classes
+        >>> predicted_label = logits.argmax(-1).item()
+        >>> print(model.config.id2label[predicted_label])
+        eating spaghetti
+        ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.vivit(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+
+        logits = self.classifier(sequence_output[:, 0, :])
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+__all__ = ["VivitModel", "VivitPreTrainedModel", "VivitForVideoClassification"]
diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2/__init__.py b/docs/transformers/build/lib/transformers/models/wav2vec2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3516b478194df713d15941c5d646fbdba987bfb0
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/wav2vec2/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
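+
+# This package initializer is lazy: under TYPE_CHECKING the star imports below expose the
+# full symbol table to static type checkers, while at runtime the module object is replaced
+# by a _LazyModule that imports each submodule only on first attribute access.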
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_wav2vec2 import *
+    from .feature_extraction_wav2vec2 import *
+    from .modeling_flax_wav2vec2 import *
+    from .modeling_tf_wav2vec2 import *
+    from .modeling_wav2vec2 import *
+    from .processing_wav2vec2 import *
+    from .tokenization_wav2vec2 import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2/configuration_wav2vec2.py b/docs/transformers/build/lib/transformers/models/wav2vec2/configuration_wav2vec2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c28aa6305f855c796a4ffaa565798daeb5d66028
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/wav2vec2/configuration_wav2vec2.py
@@ -0,0 +1,347 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Wav2Vec2 model configuration"""
+
+import functools
+import operator
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Wav2Vec2Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Wav2Vec2Model`]. It is used to instantiate a
+    Wav2Vec2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a similar configuration to that of the Wav2Vec2
+    [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32):
+            Vocabulary size of the Wav2Vec2 model. Defines the number of different tokens that can be represented by
+            the `inputs_ids` passed when calling [`Wav2Vec2Model`] or [`TFWav2Vec2Model`].
+        hidden_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 12):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 12):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 3072):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        hidden_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        activation_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for activations inside the fully connected layer.
+        attention_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout ratio for the attention probabilities.
+        final_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for the final projection layer of [`Wav2Vec2ForCTC`].
+        layerdrop (`float`, *optional*, defaults to 0.1):
+            The LayerDrop probability. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more
+            details.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+            The epsilon used by the layer normalization layers.
+        feat_extract_norm (`str`, *optional*, defaults to `"group"`):
+            The norm to be applied to 1D convolutional layers in feature encoder. One of `"group"` for group
+            normalization of only the first 1D convolutional layer or `"layer"` for layer normalization of all 1D
+            convolutional layers.
+        feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for output of the feature encoder.
+        feat_extract_activation (`str`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the 1D convolutional layers of the feature
+            extractor. If string, `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
+        feat_quantizer_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for quantized feature encoder states.
+        conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`):
+            A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the
+            feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers.
+        conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`):
+            A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length
+            of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*.
+        conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The
+            length of *conv_kernel* defines the number of convolutional layers and has to match the length of
+            *conv_dim*.
+        conv_bias (`bool`, *optional*, defaults to `False`):
+            Whether the 1D convolutional layers have a bias.
+        num_conv_pos_embeddings (`int`, *optional*, defaults to 128):
+            Number of convolutional positional embeddings. Defines the kernel size of 1D convolutional positional
+            embeddings layer.
+        num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
+            Number of groups of 1D convolutional positional embeddings layer.
+        do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
+            Whether to apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
`do_stable_layer_norm is + True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is + False` corresponds to applying layer norm after the attention layer. + apply_spec_augment (`bool`, *optional*, defaults to `True`): + Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see + [SpecAugment: A Simple Data Augmentation Method for Automatic Speech + Recognition](https://arxiv.org/abs/1904.08779). + mask_time_prob (`float`, *optional*, defaults to 0.05): + Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking + procecure generates ''mask_time_prob*len(time_axis)/mask_time_length'' independent masks over the axis. If + reasoning from the propability of each feature vector to be chosen as the start of the vector span to be + masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the + actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`. + mask_time_length (`int`, *optional*, defaults to 10): + Length of vector span along the time axis. + mask_time_min_masks (`int`, *optional*, defaults to 2),: + The minimum number of masks of length `mask_feature_length` generated along the time axis, each time step, + irrespectively of `mask_feature_prob`. Only relevant if ''mask_time_prob*len(time_axis)/mask_time_length < + mask_time_min_masks'' + mask_feature_prob (`float`, *optional*, defaults to 0.0): + Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The + masking procecure generates ''mask_feature_prob*len(feature_axis)/mask_time_length'' independent masks over + the axis. If reasoning from the propability of each feature vector to be chosen as the start of the vector + span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that overlap + may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment is + True`. + mask_feature_length (`int`, *optional*, defaults to 10): + Length of vector span along the feature axis. + mask_feature_min_masks (`int`, *optional*, defaults to 0),: + The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time + step, irrespectively of `mask_feature_prob`. Only relevant if + ''mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks'' + num_codevectors_per_group (`int`, *optional*, defaults to 320): + Number of entries in each quantization codebook (group). + num_codevector_groups (`int`, *optional*, defaults to 2): + Number of codevector groups for product codevector quantization. + contrastive_logits_temperature (`float`, *optional*, defaults to 0.1): + The temperature *kappa* in the contrastive loss. + feat_quantizer_dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for the output of the feature encoder that's used by the quantizer. + num_negatives (`int`, *optional*, defaults to 100): + Number of negative samples for the contrastive loss. + codevector_dim (`int`, *optional*, defaults to 256): + Dimensionality of the quantized feature vectors. + proj_codevector_dim (`int`, *optional*, defaults to 256): + Dimensionality of the final projection of both the quantized and the transformer features. + diversity_loss_weight (`int`, *optional*, defaults to 0.1): + The weight of the codebook diversity loss component. 
+ ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`): + Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an + instance of [`Wav2Vec2ForCTC`]. + ctc_zero_infinity (`bool`, *optional*, defaults to `False`): + Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly + occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance + of [`Wav2Vec2ForCTC`]. + use_weighted_layer_sum (`bool`, *optional*, defaults to `False`): + Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an + instance of [`Wav2Vec2ForSequenceClassification`]. + classifier_proj_size (`int`, *optional*, defaults to 256): + Dimensionality of the projection before token mean-pooling for classification. + tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`): + A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN* + module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers. + tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the + *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*. + tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`): + A tuple of integers defining the dilation factor of each 1D convolutional layer in *TDNN* module of the + *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*. + xvector_output_dim (`int`, *optional*, defaults to 512): + Dimensionality of the *XVector* embedding vectors. + add_adapter (`bool`, *optional*, defaults to `False`): + Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for + warm-starting Wav2Vec2 for SpeechEncoderDecoder models. + adapter_kernel_size (`int`, *optional*, defaults to 3): + Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. + adapter_stride (`int`, *optional*, defaults to 2): + Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`. + num_adapter_layers (`int`, *optional*, defaults to 3): + Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is + True`. + adapter_attn_dim (`int`, *optional*): + Dimension of the attention adapter weights to be used in each attention block. An example of a model using + attention adapters is [facebook/mms-1b-all](https://huggingface.co/facebook/mms-1b-all). + output_hidden_size (`int`, *optional*): + Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant + if `add_adapter is True`. 
+ + Example: + + ```python + >>> from transformers import Wav2Vec2Config, Wav2Vec2Model + + >>> # Initializing a Wav2Vec2 facebook/wav2vec2-base-960h style configuration + >>> configuration = Wav2Vec2Config() + + >>> # Initializing a model (with random weights) from the facebook/wav2vec2-base-960h style configuration + >>> model = Wav2Vec2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "wav2vec2" + + def __init__( + self, + vocab_size=32, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout=0.1, + activation_dropout=0.1, + attention_dropout=0.1, + feat_proj_dropout=0.0, + feat_quantizer_dropout=0.0, + final_dropout=0.1, + layerdrop=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + feat_extract_norm="group", + feat_extract_activation="gelu", + conv_dim=(512, 512, 512, 512, 512, 512, 512), + conv_stride=(5, 2, 2, 2, 2, 2, 2), + conv_kernel=(10, 3, 3, 3, 3, 2, 2), + conv_bias=False, + num_conv_pos_embeddings=128, + num_conv_pos_embedding_groups=16, + do_stable_layer_norm=False, + apply_spec_augment=True, + mask_time_prob=0.05, + mask_time_length=10, + mask_time_min_masks=2, + mask_feature_prob=0.0, + mask_feature_length=10, + mask_feature_min_masks=0, + num_codevectors_per_group=320, + num_codevector_groups=2, + contrastive_logits_temperature=0.1, + num_negatives=100, + codevector_dim=256, + proj_codevector_dim=256, + diversity_loss_weight=0.1, + ctc_loss_reduction="sum", + ctc_zero_infinity=False, + use_weighted_layer_sum=False, + classifier_proj_size=256, + tdnn_dim=(512, 512, 512, 512, 1500), + tdnn_kernel=(5, 3, 3, 1, 1), + tdnn_dilation=(1, 2, 3, 1, 1), + xvector_output_dim=512, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + add_adapter=False, + adapter_kernel_size=3, + adapter_stride=2, + num_adapter_layers=3, + output_hidden_size=None, + adapter_attn_dim=None, + **kwargs, + ): + super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) + self.hidden_size = hidden_size + self.feat_extract_norm = feat_extract_norm + self.feat_extract_activation = feat_extract_activation + self.conv_dim = list(conv_dim) + self.conv_stride = list(conv_stride) + self.conv_kernel = list(conv_kernel) + self.conv_bias = conv_bias + self.num_conv_pos_embeddings = num_conv_pos_embeddings + self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups + self.num_feat_extract_layers = len(self.conv_dim) + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_attention_heads = num_attention_heads + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.feat_proj_dropout = feat_proj_dropout + self.final_dropout = final_dropout + self.layerdrop = layerdrop + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.vocab_size = vocab_size + self.do_stable_layer_norm = do_stable_layer_norm + self.use_weighted_layer_sum = use_weighted_layer_sum + + if ( + (len(self.conv_stride) != self.num_feat_extract_layers) + or (len(self.conv_kernel) != self.num_feat_extract_layers) + or (len(self.conv_dim) != self.num_feat_extract_layers) + ): + raise ValueError( + "Configuration for convolutional layers is incorrect. 
It is required that `len(config.conv_dim)` ==" + " `len(config.conv_stride)` == `len(config.conv_kernel)`, but is `len(config.conv_dim) =" + f" {len(self.conv_dim)}`, `len(config.conv_stride) = {len(self.conv_stride)}`," + f" `len(config.conv_kernel) = {len(self.conv_kernel)}`." + ) + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks + + # parameters for pretraining with codevector quantized representations + self.num_codevectors_per_group = num_codevectors_per_group + self.num_codevector_groups = num_codevector_groups + self.contrastive_logits_temperature = contrastive_logits_temperature + self.feat_quantizer_dropout = feat_quantizer_dropout + self.num_negatives = num_negatives + self.codevector_dim = codevector_dim + self.proj_codevector_dim = proj_codevector_dim + self.diversity_loss_weight = diversity_loss_weight + + # ctc loss + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity + + # adapter + self.add_adapter = add_adapter + self.adapter_kernel_size = adapter_kernel_size + self.adapter_stride = adapter_stride + self.num_adapter_layers = num_adapter_layers + self.output_hidden_size = output_hidden_size or hidden_size + self.adapter_attn_dim = adapter_attn_dim + + # SequenceClassification-specific parameter. Feel free to ignore for other classes. + self.classifier_proj_size = classifier_proj_size + + # XVector-specific parameters. Feel free to ignore for other classes. + self.tdnn_dim = list(tdnn_dim) + self.tdnn_kernel = list(tdnn_kernel) + self.tdnn_dilation = list(tdnn_dilation) + self.xvector_output_dim = xvector_output_dim + + @property + def inputs_to_logits_ratio(self): + return functools.reduce(operator.mul, self.conv_stride, 1) + + +__all__ = ["Wav2Vec2Config"] diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..5613f83a86b4e79450a4068e0152df649565bfd2 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py @@ -0,0 +1,385 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
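Before moving on to the conversion script, it is worth noting that the `inputs_to_logits_ratio` property defined at the end of the configuration above is simply the product of `conv_stride`, which ties the raw-waveform length to the number of encoder frames. A minimal sanity-check sketch, assuming only the default `Wav2Vec2Config` shown above and the 16 kHz sampling rate used throughout these files:

```python
from transformers import Wav2Vec2Config

config = Wav2Vec2Config()

# Product of the default conv strides (5, 2, 2, 2, 2, 2, 2) -> 320,
# i.e. roughly one encoder frame per 320 input samples.
assert config.inputs_to_logits_ratio == 320

# One second of 16 kHz audio therefore yields about 16000 / 320 = 50
# encoder frames, ignoring edge effects from the convolution kernels.
num_samples = 16_000
print(num_samples // config.inputs_to_logits_ratio)  # 50
```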
+"""Convert Wav2Vec2 checkpoint.""" + +import argparse +import json +import os + +import fairseq +import torch +from fairseq.data import Dictionary + +from transformers import ( + Wav2Vec2Config, + Wav2Vec2CTCTokenizer, + Wav2Vec2FeatureExtractor, + Wav2Vec2ForCTC, + Wav2Vec2ForPreTraining, + Wav2Vec2Processor, + logging, +) +from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2ForSequenceClassification + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +MAPPING = { + "post_extract_proj": "feature_projection.projection", + "encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "self_attn.k_proj": "encoder.layers.*.attention.k_proj", + "self_attn.v_proj": "encoder.layers.*.attention.v_proj", + "self_attn.q_proj": "encoder.layers.*.attention.q_proj", + "self_attn.out_proj": "encoder.layers.*.attention.out_proj", + "self_attn_layer_norm": "encoder.layers.*.layer_norm", + "fc1": "encoder.layers.*.feed_forward.intermediate_dense", + "fc2": "encoder.layers.*.feed_forward.output_dense", + "final_layer_norm": "encoder.layers.*.final_layer_norm", + "encoder.layer_norm": "encoder.layer_norm", + "adapter_layer": "encoder.layers.*.adapter_layer", + "w2v_model.layer_norm": "feature_projection.layer_norm", + "quantizer.weight_proj": "quantizer.weight_proj", + "quantizer.vars": "quantizer.codevectors", + "project_q": "project_q", + "final_proj": "project_hid", + "w2v_encoder.proj": "lm_head", + "mask_emb": "masked_spec_embed", + "pooling_layer.linear": "projector", + "pooling_layer.projection": "classifier", +} +TOP_LEVEL_KEYS = [ + "lm_head", + "quantizer.weight_proj", + "quantizer.codevectors", + "project_q", + "project_hid", + "projector", + "classifier", +] + + +def read_txt_into_dict(filename): + result = {} + with open(filename, "r") as file: + for line_number, line in enumerate(file): + line = line.strip() + if line: + words = line.split() + key = line_number + value = words[0] + result[key] = value + return result + + +def set_recursively(key, value, full_name, weight_type, hf_pointer): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + hf_param_name = None + for param_key in PARAM_MAPPING.keys(): + if full_name.endswith(param_key): + hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]] + weight_type = "param" + + # fairseq uses nn.utils.weight_norm() while transformers switches to nn.utils.parametrizations.weight_norm() + # the mapping between two versions: + # https://github.com/pytorch/pytorch/blob/56935684c3dfad7841c83c719eeebecb560fe466/torch/nn/utils/parametrizations.py#L389-L395 + + if weight_type is not None and weight_type != "param": + if weight_type == "weight_g" and not hasattr(hf_pointer, "weight_g"): + hf_shape = hf_pointer.parametrizations.weight.original0.shape + elif weight_type == "weight_v" and not hasattr(hf_pointer, "weight_v"): + hf_shape = hf_pointer.parametrizations.weight.original1.shape + else: + hf_shape = getattr(hf_pointer, weight_type).shape + elif weight_type is not None and weight_type == "param": + shape_pointer = hf_pointer + for attribute in hf_param_name.split("."): + shape_pointer = getattr(shape_pointer, attribute) + hf_shape = shape_pointer.shape + + # let's reduce dimension + value = value[0] + else: + hf_shape = hf_pointer.shape + + if hf_shape != value.shape: + raise ValueError( + f"Shape of hf {key + '.' 
+ weight_type if weight_type is not None else ''} is {hf_shape}, but should be" + f" {value.shape} for {full_name}" + ) + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + if hasattr(hf_pointer, "weight_g"): + hf_pointer.weight_g.data = value + else: + hf_pointer.parametrizations.weight.original0.data = value + elif weight_type == "weight_v": + if hasattr(hf_pointer, "weight_v"): + hf_pointer.weight_v.data = value + else: + hf_pointer.parametrizations.weight.original1.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + elif weight_type == "param": + for attribute in hf_param_name.split("."): + hf_pointer = getattr(hf_pointer, attribute) + hf_pointer.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") + + +def rename_dict(key, value, full_name, weight_type, hf_dict): + hf_param_name = None + for param_key in PARAM_MAPPING.keys(): + if full_name.endswith(param_key): + hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]] + weight_type = "param" + + if weight_type is not None and weight_type != "param": + full_key = ".".join([key, weight_type]) + elif weight_type is not None and weight_type == "param": + full_key = ".".join([key, hf_param_name]) + else: + full_key = key + + hf_dict[full_key] = value if "lm_head" in full_key else value[0] + + +PARAM_MAPPING = { + "W_a": "linear_1.weight", + "W_b": "linear_2.weight", + "b_a": "linear_1.bias", + "b_b": "linear_2.bias", + "ln_W": "norm.weight", + "ln_b": "norm.bias", +} + + +def load_wav2vec2_layer(name, value, hf_model=None, hf_dict=None): + is_used = False + for key, mapped_key in MAPPING.items(): + mapped_key = "wav2vec2." 
+ mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key + if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "bias" in name: + weight_type = "bias" + elif "weight" in name: + # TODO: don't match quantizer.weight_proj + weight_type = "weight" + else: + weight_type = None + if hf_dict is not None: + rename_dict(mapped_key, value, name, weight_type, hf_dict) + else: + set_recursively(mapped_key, value, name, weight_type, hf_model) + return is_used + return is_used + + +def recursively_load_weights(fairseq_model, hf_model, is_headless): + unused_weights = [] + fairseq_dict = fairseq_model.state_dict() + + feature_extractor = hf_model.wav2vec2.feature_extractor + + for name, value in fairseq_dict.items(): + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_extractor, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + else: + is_used = load_wav2vec2_layer(name, value, hf_model) + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + +def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm): + name = full_name.split("conv_layers.")[-1] + items = name.split(".") + layer_id = int(items[0]) + type_id = int(items[1]) + + if type_id == 0: + if "bias" in name: + if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape: + raise ValueError( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.bias.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape: + raise ValueError( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].conv.weight.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm): + if "bias" in name: + if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape: + raise ValueError( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found." + ) + feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape: + raise ValueError( + f"{full_name} has size {value.shape}, but" + f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found." 
+                )
+            feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
+            logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
+    else:
+        unused_weights.append(full_name)
+
+
+@torch.no_grad()
+def convert_wav2vec2_checkpoint(
+    checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True, is_seq_class=False
+):
+    """
+    Copy/paste/tweak model's weights to transformers design.
+    """
+    if config_path is not None:
+        config = Wav2Vec2Config.from_pretrained(config_path)
+    else:
+        config = Wav2Vec2Config()
+
+    if is_seq_class:
+        id2label = read_txt_into_dict(dict_path)
+        config.id2label = id2label
+        hf_wav2vec = Wav2Vec2ForSequenceClassification(config)
+        feature_extractor = Wav2Vec2FeatureExtractor(
+            feature_size=1,
+            sampling_rate=16000,
+            padding_value=0,
+            do_normalize=True,
+            return_attention_mask=True,
+        )
+        feature_extractor.save_pretrained(pytorch_dump_folder_path)
+
+    elif is_finetuned:
+        if dict_path:
+            target_dict = Dictionary.load(dict_path)
+
+            # important change bos & pad token id since CTC symbol is <pad> and
+            # not <s> as in fairseq
+            config.bos_token_id = target_dict.pad_index
+            config.pad_token_id = target_dict.bos_index
+            config.eos_token_id = target_dict.eos_index
+            config.vocab_size = len(target_dict.symbols)
+            vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
+            if not os.path.isdir(pytorch_dump_folder_path):
+                logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
+                return
+            os.makedirs(pytorch_dump_folder_path, exist_ok=True)
+            vocab_dict = target_dict.indices
+
+            # fairseq has the <pad> and <s> switched
+            vocab_dict["<pad>"] = 0
+            vocab_dict["<s>"] = 1
+            with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
+                json.dump(vocab_dict, vocab_handle)
+            tokenizer = Wav2Vec2CTCTokenizer(
+                vocab_path,
+                unk_token=target_dict.unk_word,
+                pad_token=target_dict.pad_word,
+                bos_token=target_dict.bos_word,
+                eos_token=target_dict.eos_word,
+                word_delimiter_token="|",
+                do_lower_case=False,
+            )
+            return_attention_mask = True if config.feat_extract_norm == "layer" else False
+            feature_extractor = Wav2Vec2FeatureExtractor(
+                feature_size=1,
+                sampling_rate=16000,
+                padding_value=0,
+                do_normalize=True,
+                return_attention_mask=return_attention_mask,
+            )
+            processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+            processor.save_pretrained(pytorch_dump_folder_path)
+
+        hf_wav2vec = Wav2Vec2ForCTC(config)
+    else:
+        hf_wav2vec = Wav2Vec2ForPreTraining(config)
+
+    if is_finetuned or is_seq_class:
+        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+            [checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
+        )
+    else:
+        task_arg = argparse.Namespace(task="audio_pretraining")
+        task = fairseq.tasks.setup_task(task_arg)
+
+        model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path], task=task)
+
+    model = model[0].eval()
+
+    recursively_load_weights(model, hf_wav2vec, not is_finetuned)
+
+    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
+    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
+    parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
+    parser.add_argument("--config_path",
default=None, type=str, help="Path to hf config.json of model to convert") + parser.add_argument( + "--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not" + ) + parser.add_argument( + "--is_seq_class", + action="store_true", + help="Whether the model to convert is a fine-tuned sequence classification model or not", + ) + args = parser.parse_args() + + is_finetuned = not args.not_finetuned and not args.is_seq_class + convert_wav2vec2_checkpoint( + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.config_path, + args.dict_path, + is_finetuned, + args.is_seq_class, + ) diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py b/docs/transformers/build/lib/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..fa33416c8bdc6d0a82c50c364438861626d26879 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/wav2vec2/convert_wav2vec2_original_s3prl_checkpoint_to_pytorch.py @@ -0,0 +1,109 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Hubert checkpoint.""" + +import argparse + +import torch + +from transformers import ( + Wav2Vec2Config, + Wav2Vec2FeatureExtractor, + Wav2Vec2ForAudioFrameClassification, + Wav2Vec2ForSequenceClassification, + Wav2Vec2ForXVector, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +def convert_classification(base_model_name, hf_config, downstream_dict): + model = Wav2Vec2ForSequenceClassification.from_pretrained(base_model_name, config=hf_config) + model.projector.weight.data = downstream_dict["projector.weight"] + model.projector.bias.data = downstream_dict["projector.bias"] + model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"] + model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"] + return model + + +def convert_diarization(base_model_name, hf_config, downstream_dict): + model = Wav2Vec2ForAudioFrameClassification.from_pretrained(base_model_name, config=hf_config) + model.classifier.weight.data = downstream_dict["model.linear.weight"] + model.classifier.bias.data = downstream_dict["model.linear.bias"] + return model + + +def convert_xvector(base_model_name, hf_config, downstream_dict): + model = Wav2Vec2ForXVector.from_pretrained(base_model_name, config=hf_config) + model.projector.weight.data = downstream_dict["connector.weight"] + model.projector.bias.data = downstream_dict["connector.bias"] + for i, kernel_size in enumerate(hf_config.tdnn_kernel): + model.tdnn[i].kernel.weight.data = downstream_dict[ + f"model.framelevel_feature_extractor.module.{i}.kernel.weight" + ] + model.tdnn[i].kernel.bias.data = downstream_dict[f"model.framelevel_feature_extractor.module.{i}.kernel.bias"] + + model.feature_extractor.weight.data = 
downstream_dict["model.utterancelevel_feature_extractor.linear1.weight"] + model.feature_extractor.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear1.bias"] + model.classifier.weight.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.weight"] + model.classifier.bias.data = downstream_dict["model.utterancelevel_feature_extractor.linear2.bias"] + model.objective.weight.data = downstream_dict["objective.W"] + return model + + +@torch.no_grad() +def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path): + """ + Copy/paste/tweak model's weights to transformers design. + """ + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + + downstream_dict = checkpoint["Downstream"] + + hf_config = Wav2Vec2Config.from_pretrained(config_path) + hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( + base_model_name, return_attention_mask=True, do_normalize=False + ) + + arch = hf_config.architectures[0] + if arch.endswith("ForSequenceClassification"): + hf_model = convert_classification(base_model_name, hf_config, downstream_dict) + elif arch.endswith("ForAudioFrameClassification"): + hf_model = convert_diarization(base_model_name, hf_config, downstream_dict) + elif arch.endswith("ForXVector"): + hf_model = convert_xvector(base_model_name, hf_config, downstream_dict) + else: + raise NotImplementedError(f"S3PRL weights conversion is not supported for {arch}") + + if hf_config.use_weighted_layer_sum: + hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"] + + hf_feature_extractor.save_pretrained(model_dump_path) + hf_model.save_pretrained(model_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model." + ) + parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.") + parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.") + args = parser.parse_args() + convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path) diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/docs/transformers/build/lib/transformers/models/wav2vec2/feature_extraction_wav2vec2.py new file mode 100644 index 0000000000000000000000000000000000000000..3dde386b32af9b8b46308d0dc7c9e879fbff054a --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -0,0 +1,243 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
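The file that follows defines the feature extractor, the entry point that turns raw waveforms into padded, normalized `input_values`. A minimal usage sketch, following the `__call__` signature defined below; the zero-filled arrays are placeholder audio, not real data:

```python
import numpy as np
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=True,
)

# Two mono clips of different lengths; `padding=True` pads to the longest.
raw_speech = [
    np.zeros(16000, dtype=np.float32),  # 1.0 s placeholder clip
    np.zeros(8000, dtype=np.float32),   # 0.5 s placeholder clip
]
inputs = feature_extractor(raw_speech, sampling_rate=16000, padding=True, return_tensors="np")

print(inputs["input_values"].shape)    # (2, 16000)
print(inputs["attention_mask"].shape)  # (2, 16000)
```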
+""" +Feature extractor class for Wav2Vec2 +""" + +from typing import List, Optional, Union + +import numpy as np + +from ...feature_extraction_sequence_utils import SequenceFeatureExtractor +from ...feature_extraction_utils import BatchFeature +from ...utils import PaddingStrategy, TensorType, logging + + +logger = logging.get_logger(__name__) + + +class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a Wav2Vec2 feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + Args: + feature_size (`int`, *optional*, defaults to 1): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + padding_value (`float`, *optional*, defaults to 0.0): + The value that is used to fill the padding values. + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly + improve the performance for some models, *e.g.*, + [wav2vec2-lv60](https://huggingface.co/models?search=lv60). + return_attention_mask (`bool`, *optional*, defaults to `False`): + Whether or not [`~Wav2Vec2FeatureExtractor.__call__`] should return `attention_mask`. + + + + Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as + [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using + `attention_mask`. For such models, `input_values` should simply be padded with 0 and no `attention_mask` + should be passed. + + For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as + [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should be + passed for batched inference. 
+ + """ + + model_input_names = ["input_values", "attention_mask"] + + def __init__( + self, + feature_size=1, + sampling_rate=16000, + padding_value=0.0, + return_attention_mask=False, + do_normalize=True, + **kwargs, + ): + super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) + self.return_attention_mask = return_attention_mask + self.do_normalize = do_normalize + + @staticmethod + def zero_mean_unit_var_norm( + input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0 + ) -> List[np.ndarray]: + """ + Every array in the list is normalized to have zero mean and unit variance + """ + if attention_mask is not None: + attention_mask = np.array(attention_mask, np.int32) + normed_input_values = [] + + for vector, length in zip(input_values, attention_mask.sum(-1)): + normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) + if length < normed_slice.shape[0]: + normed_slice[length:] = padding_value + + normed_input_values.append(normed_slice) + else: + normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] + + return normed_input_values + + def __call__( + self, + raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]], + padding: Union[bool, str, PaddingStrategy] = False, + max_length: Optional[int] = None, + truncation: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + sampling_rate: Optional[int] = None, + **kwargs, + ) -> BatchFeature: + """ + Main method to featurize and prepare for the model one or several sequence(s). + + Args: + raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`): + The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float + values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not + stereo, i.e. single float per timestep. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`): + Activates truncation to cut input sequences longer than *max_length* to *max_length*. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific feature_extractor's default. 
+ + [What are attention masks?](../glossary#attention-mask) + + + + Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as + [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using + `attention_mask`. For such models, `input_values` should simply be padded with 0 and no + `attention_mask` should be passed. + + For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as + [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should + be passed for batched inference. + + + + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + sampling_rate (`int`, *optional*): + The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass + `sampling_rate` at the forward call to prevent silent errors. + padding_value (`float`, *optional*, defaults to 0.0): + """ + + if sampling_rate is not None: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" + f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with" + f" {self.sampling_rate} and not {sampling_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + # always return batch + if not is_batched: + raw_speech = [raw_speech] + + # convert into correct format for padding + encoded_inputs = BatchFeature({"input_values": raw_speech}) + + padded_inputs = self.pad( + encoded_inputs, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + # convert input values to correct format + input_values = padded_inputs["input_values"] + if not isinstance(input_values[0], np.ndarray): + padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values] + elif ( + not isinstance(input_values, np.ndarray) + and isinstance(input_values[0], np.ndarray) + and input_values[0].dtype is np.dtype(np.float64) + ): + padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values] + elif isinstance(input_values, np.ndarray) and input_values.dtype is np.dtype(np.float64): + padded_inputs["input_values"] = input_values.astype(np.float32) + + # convert attention_mask to correct format + attention_mask = padded_inputs.get("attention_mask") + if attention_mask is not None: + padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] + + # zero-mean and unit-variance normalization + if self.do_normalize: + attention_mask = ( + attention_mask + if self._get_padding_strategies(padding, max_length=max_length) is not 
PaddingStrategy.DO_NOT_PAD + else None + ) + padded_inputs["input_values"] = self.zero_mean_unit_var_norm( + padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value + ) + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + return padded_inputs + + +__all__ = ["Wav2Vec2FeatureExtractor"] diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/docs/transformers/build/lib/transformers/models/wav2vec2/modeling_flax_wav2vec2.py new file mode 100644 index 0000000000000000000000000000000000000000..547076205018b7a15e2d7785eeea8c9cc9eed109 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/wav2vec2/modeling_flax_wav2vec2.py @@ -0,0 +1,1428 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Flax Wav2Vec2 model.""" + +from functools import partial +from typing import Optional, Tuple, Union + +import flax +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from .configuration_wav2vec2 import Wav2Vec2Config + + +logger = logging.get_logger(__name__) + + +@flax.struct.dataclass +class FlaxWav2Vec2BaseModelOutput(ModelOutput): + """ + Output type of [`FlaxWav2Vec2BaseModelOutput`], with potential hidden states and attentions. + + Args: + last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`): + Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim` + being the dimension of the last convolutional layer. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: jnp.ndarray = None + extract_features: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +@flax.struct.dataclass +class FlaxWav2Vec2ForPreTrainingOutput(ModelOutput): + """ + Output type of [`FlaxWav2Vec2ForPreTrainingOutput`], with potential hidden states and attentions. + + Args: + loss (*optional*, returned when model is in train mode, `jnp.ndarray` of shape `(1,)`): + Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official + paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss. + projected_states (`jnp.ndarray` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked + projected quantized states. + projected_quantized_states (`jnp.ndarray` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive + target vectors for contrastive loss. + hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape + `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + projected_states: jnp.ndarray = None + projected_quantized_states: jnp.ndarray = None + codevector_perplexity: jnp.ndarray = None + hidden_states: Optional[Tuple[jnp.ndarray]] = None + attentions: Optional[Tuple[jnp.ndarray]] = None + + +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[np.ndarray] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + mask_prob: + probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_length: size of the mask + min_masks: minimum number of masked spans + + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and" + f" `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + np.random.rand(1).item()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + + # get random indices to mask + spec_aug_mask_idxs = np.array( + [ + np.random.choice(np.arange(sequence_length - (mask_length - 1)), num_masked_spans, replace=False) + for _ in range(batch_size) + ] + ) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to(spec_aug_mask_idxs[:, :, None], (batch_size, num_masked_spans, mask_length)) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, num_masked_spans * mask_length) + + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, num_masked_spans, mask_length)).reshape( + batch_size, num_masked_spans * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + if attention_mask is not None: + # make sure padded input ids cannot be masked + spec_aug_mask = np.where(attention_mask, spec_aug_mask, False) + + return spec_aug_mask + + +def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attention_mask: Optional[np.ndarray] = None): + """ + Sample `num_negatives` vectors from feature vectors. + """ + batch_size, sequence_length, hidden_size = features_shape + if sequence_length <= 1: + raise ValueError( + "`features should have `sequence_length` > 1, but are of shape " + f"(batch_size, sequence_length, hidden_size) = ({batch_size, sequence_length, hidden_size})." 
+ ) + + # get `num_negatives` random vector indices from the same utterance + sampled_negative_indices = [] + for batch_idx in range(batch_size): + high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1 + sampled_indices_slice = np.random.randint(0, high, size=(num_negatives * sequence_length,)) + sampled_negative_indices.append(sampled_indices_slice) + + sampled_negative_indices = np.asarray(sampled_negative_indices, dtype=np.int32) + + # generate indices of the positive vectors themselves, repeat them `num_negatives` times + feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten() + + # avoid sampling the same positive vector, but keep the distribution uniform + sampled_negative_indices[sampled_negative_indices >= feature_indices] += 1 + + # correct for batch size + for batch_idx in range(1, batch_size): + sampled_negative_indices[batch_idx] += batch_idx * sequence_length + + return sampled_negative_indices + + +WAV2VEC2_START_DOCSTRING = r""" + Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech + Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael + Auli. + + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. + + Finally, this model supports inherent JAX features such as: + + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + + +WAV2VEC2_INPUTS_DOCSTRING = r""" + Args: + input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. 
Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `jnp.ndarray`. See [`Wav2Vec2Processor.__call__`] for details. + attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed + if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor + has `config.return_attention_mask == False`, such as + [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be + passed to avoid degraded performance when doing batched inference. For such models `input_values` should + simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly + different results depending on whether `input_values` is padded or not. + mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in *config.proj_codevector_dim* space. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +class FlaxWav2Vec2LayerNormConvLayer(nn.Module): + config: Wav2Vec2Config + layer_id: int = 0 + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1 + self.out_conv_dim = self.config.conv_dim[self.layer_id] + + self.conv = nn.Conv( + features=self.config.conv_dim[self.layer_id], + kernel_size=(self.config.conv_kernel[self.layer_id],), + strides=(self.config.conv_stride[self.layer_id],), + use_bias=self.config.conv_bias, + kernel_init=jax.nn.initializers.he_normal(), + padding="VALID", + dtype=self.dtype, + ) + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.activation = ACT2FN[self.config.feat_extract_activation] + + def __call__(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class FlaxConvWithWeightNorm(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + features=self.config.hidden_size, + kernel_size=(self.config.num_conv_pos_embeddings,), + kernel_init=jax.nn.initializers.he_normal(), + padding="VALID", + feature_group_count=self.config.num_conv_pos_embedding_groups, + dtype=self.dtype, + ) + weight_shape = ( + self.conv.features, + self.conv.features // self.conv.feature_group_count, + self.conv.kernel_size[0], + ) + self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape) + self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]) + self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,)) + self.prev_padding = self.conv.kernel_size[0] // 2 + + def _get_normed_weights(self): + weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :] + normed_weight_v = jnp.divide(self.weight_v, weight_v_norm) + normed_kernel = jnp.multiply(normed_weight_v, self.weight_g) + return normed_kernel + + def __call__(self, hidden_states): + kernel = self._get_normed_weights() + hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0))) + hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states) + return hidden_states + + +class FlaxWav2Vec2PositionalConvEmbedding(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype) + self.activation = ACT2FN[self.config.feat_extract_activation] + self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0 + + def __call__(self, hidden_states): + hidden_states = hidden_states.transpose((0, 1, 2)) + + hidden_states = self.conv(hidden_states) + + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, : -self.num_pad_remove, :] + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose((0, 1, 2)) + return hidden_states + + +class FlaxConvLayersCollection(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + if self.config.feat_extract_norm == "layer": + self.layers = [ + FlaxWav2Vec2LayerNormConvLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) + for i in range(self.config.num_feat_extract_layers) + ] + elif self.config.feat_extract_norm == "group": + raise NotImplementedError("At the moment only ``config.feat_extact_norm == 
'layer'`` is supported") + else: + raise ValueError( + f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group'," + " 'layer']" + ) + + def __call__(self, hidden_states): + for i, conv_layer in enumerate(self.layers): + hidden_states = conv_layer(hidden_states) + return hidden_states + + +class FlaxWav2Vec2FeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype) + + def __call__(self, input_values, freeze_feature_encoder=False): + hidden_states = input_values[:, :, None] + hidden_states = self.conv_layers(hidden_states) + if freeze_feature_encoder: + hidden_states = jax.lax.stop_gradient(hidden_states) + return hidden_states + + +class FlaxWav2Vec2FeatureProjection(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.projection = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout) + + def __call__(self, hidden_states, deterministic=True): + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states, norm_hidden_states + + +class FlaxWav2Vec2Attention(nn.Module): + config: Wav2Vec2Config + embed_dim: int + num_heads: int + dropout: float = 0.0 + bias: bool = True + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() + self.out_proj = dense() + + self.dropout_layer = nn.Dropout(rate=self.dropout) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + """Input shape: Batch x Time x Channel""" + + # get query proj + query_states = self.q_proj(hidden_states) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + if attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # Convert the boolean attention mask to an attention bias. 
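+        # (0.0 at positions that may be attended to, and a large negative value at masked positions, so the
+        # bias can simply be added to the raw attention scores before the softmax)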
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class FlaxWav2Vec2FeedForward(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout) + + self.intermediate_dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + if isinstance(self.config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[self.config.hidden_act] + else: + self.intermediate_act_fn = self.config.hidden_act + + self.output_dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.attention = FlaxWav2Vec2Attention( + config=self.config, + embed_dim=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout) + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype) + self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False): + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights = self.attention( + hidden_states, attention_mask=attention_mask, deterministic=deterministic + ) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward( + self.final_layer_norm(hidden_states), deterministic=deterministic + ) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def 
setup(self): + self.layers = [ + FlaxWav2Vec2EncoderLayerStableLayerNorm(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask=None, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxWav2Vec2StableLayerNormEncoder(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype) + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout) + self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask=None, + deterministic=True, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + if attention_mask is not None: + # make sure padded tokens are not attended to + hidden_states = jnp.where( + jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0 + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + + hidden_states = hidden_states + position_embeddings + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = self.layer_norm(outputs[0]) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_state,) + + if not return_dict: + outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions + ) + + +class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module): + """ + Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH + GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information. 
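+
+    As a rough, illustrative sketch (not the module itself) of the differentiable sampling step used during
+    training, assuming standalone JAX/Flax imports:
+
+    ```python
+    >>> import jax
+    >>> import jax.numpy as jnp
+    >>> from flax import linen as nn
+
+    >>> logits = jnp.array([[2.0, 0.5, 1.0]])  # unnormalized scores over the codevectors of one group
+    >>> gumbels = jax.random.gumbel(jax.random.PRNGKey(0), logits.shape)
+    >>> temperature = 2.0
+    >>> soft_probs = nn.softmax((logits + gumbels) / temperature)  # a differentiable "sample" over codevectors
+    ```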
+ """ + + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.num_groups = self.config.num_codevector_groups + self.num_vars = self.config.num_codevectors_per_group + + if self.config.codevector_dim % self.num_groups != 0: + raise ValueError( + f"`config.codevector_dim {self.config.codevector_dim} must be divisible by" + f" `config.num_codevector_groups` {self.num_groups} for concatenation" + ) + + # storage for codebook variables (codewords) + self.codevectors = self.param( + "codevectors", + jax.nn.initializers.uniform(), + (1, self.num_groups * self.num_vars, self.config.codevector_dim // self.num_groups), + ) + self.weight_proj = nn.Dense( + self.num_groups * self.num_vars, + kernel_init=jax.nn.initializers.normal(1.0), + dtype=self.dtype, + ) + + @staticmethod + def _compute_perplexity(probs, mask=None): + if mask is not None: + mask_extended = jnp.broadcast_to(mask.flatten()[:, None, None], probs.shape) + probs = jnp.where(mask_extended, probs, jnp.zeros_like(probs)) + marginal_probs = probs.sum(axis=0) / mask.sum() + else: + marginal_probs = probs.mean(axis=0) + + perplexity = jnp.exp(-jnp.sum(marginal_probs * jnp.log(marginal_probs + 1e-7), axis=-1)).sum() + return perplexity + + def __call__(self, hidden_states, mask_time_indices=None, deterministic=True, temperature=1): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.reshape(batch_size * sequence_length * self.num_groups, -1) + + if not deterministic: + # sample code vector probs via gumbel in differentiateable way + gumbel_rng = self.make_rng("gumbel") + gumbels = jax.random.gumbel(gumbel_rng, hidden_states.shape) + codevector_probs = nn.softmax((hidden_states + gumbels) / temperature) + + # compute perplexity + codevector_soft_dist = nn.softmax( + hidden_states.reshape(batch_size * sequence_length, self.num_groups, -1), axis=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(axis=-1) + codevector_probs = jax.nn.one_hot(codevector_idx, hidden_states.shape[-1]) * 1.0 + codevector_probs = codevector_probs.reshape(batch_size * sequence_length, self.num_groups, -1) + perplexity = self._compute_perplexity(codevector_probs, mask_time_indices) + + codevector_probs = codevector_probs.reshape(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = jnp.expand_dims(codevector_probs, axis=-1) * self.codevectors + codevectors = codevectors_per_group.reshape(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).reshape(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +class FlaxWav2Vec2Adapter(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + # hidden_states require down-projection if feature dims don't match + if self.config.output_hidden_size != self.config.hidden_size: + self.proj = nn.Dense( + self.config.output_hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + else: + self.proj = self.proj_layer_norm = None + + self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype) + + def 
__call__(self, hidden_states, deterministic=True): + # down-project hidden_states if required + if self.proj is not None and self.proj_layer_norm is not None: + hidden_states = self.proj(hidden_states) + hidden_states = self.proj_layer_norm(hidden_states) + + hidden_states = self.layers(hidden_states) + + return hidden_states + + +class FlaxWav2Vec2AdapterLayer(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.conv = nn.Conv( + features=2 * self.config.output_hidden_size, + kernel_size=(self.config.adapter_kernel_size,), + strides=(self.config.adapter_stride,), + padding=((1, 1),), + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = nn.glu(hidden_states, axis=2) + + return hidden_states + + +class FlaxWav2Vec2AdapterLayersCollection(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.layers = [ + FlaxWav2Vec2AdapterLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_adapter_layers) + ] + + def __call__(self, hidden_states): + for conv_layer in self.layers: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Wav2Vec2Config + base_model_prefix: str = "wav2vec2" + main_input_name = "input_values" + module_class: nn.Module = None + + def __init__( + self, + config: Wav2Vec2Config, + input_shape: Tuple = (1, 1024), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs, + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_values = jnp.zeros(input_shape, dtype="i4") + attention_mask = jnp.ones_like(input_values) + params_rng, dropout_rng = jax.random.split(rng, 2) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) + def __call__( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + freeze_feature_encoder: bool = False, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_values.shape 
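+        # (when `attention_mask` is omitted, a mask of ones is built below, i.e. no frames are treated as padding)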
+ + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + return self.module.apply( + inputs, + jnp.array(input_values, dtype="f4"), + jnp.array(attention_mask, dtype="i4"), + mask_time_indices, + not train, + output_attentions, + output_hidden_states, + freeze_feature_encoder, + return_dict, + rngs=rngs, + ) + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None + ): + return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter) + + +class FlaxWav2Vec2Module(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype) + self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype) + self.masked_spec_embed = self.param( + "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,) + ) + + if self.config.do_stable_layer_norm: + self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype) + else: + raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.") + + self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None + + def __call__( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + deterministic=True, + output_attentions=None, + output_hidden_states=None, + freeze_feature_encoder=False, + return_dict=None, + ): + extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder) + + # make sure that no loss is computed on padded inputs + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic) + if mask_time_indices is not None: # apply SpecAugment along time axis with given indices + hidden_states = jnp.where( + jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape), + jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape), + hidden_states, + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return FlaxWav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None + ): + """ + Computes the output length of the convolutional layers + """ + + add_adapter = self.config.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from 
https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return (input_length - kernel_size) // stride + 1
+
+        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+        return input_lengths
+
+    def _get_feature_vector_attention_mask(
+        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
+    ):
+        # Effectively attention_mask.sum(-1), but not in-place so that it can
+        # run in inference mode.
+        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
+
+        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+
+        batch_size = attention_mask.shape[0]
+
+        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
+        # these two operations make sure that all values
+        # before the output lengths indices are attended to
+        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
+        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
+        return attention_mask
+
+
+@add_start_docstrings(
+    "The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
+    WAV2VEC2_START_DOCSTRING,
+)
+class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
+    module_class = FlaxWav2Vec2Module
+
+
+FLAX_WAV2VEC2_MODEL_DOCSTRING = """
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoProcessor, FlaxWav2Vec2Model
+    >>> from datasets import load_dataset
+    >>> import soundfile as sf
+
+    >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-lv60")
+    >>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")
+
+
+    >>> def map_to_array(batch):
+    ...     speech, _ = sf.read(batch["file"])
+    ...     batch["speech"] = speech
+    ...     return batch
+
+
+    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+    >>> ds = ds.map(map_to_array)
+
+    >>> input_values = processor(
+    ...     ds["speech"][0], sampling_rate=16_000, return_tensors="np"
+    ... 
).input_values # Batch size 1 + >>> hidden_states = model(input_values).last_hidden_state + ``` +""" + +overwrite_call_docstring( + FlaxWav2Vec2Model, + WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_MODEL_DOCSTRING, +) +append_replace_return_docstrings( + FlaxWav2Vec2Model, output_type=FlaxWav2Vec2BaseModelOutput, config_class=Wav2Vec2Config +) + + +class FlaxWav2Vec2ForCTCModule(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.final_dropout) + self.lm_head = nn.Dense( + self.config.vocab_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + deterministic=True, + output_attentions=None, + output_hidden_states=None, + freeze_feature_encoder=False, + return_dict=None, + ): + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + mask_time_indices=mask_time_indices, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + freeze_feature_encoder=freeze_feature_encoder, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + + logits = self.lm_head(hidden_states) + + if not return_dict: + return (logits,) + outputs[2:] + + return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) + + def _get_feat_extract_output_lengths( + self, + input_lengths: Union[jnp.ndarray, int], + add_adapter: Optional[bool] = None, + ): + """ + Computes the output length of the convolutional layers + """ + + add_adapter = self.config.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + if add_adapter: + for _ in range(self.config.num_adapter_layers): + input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) + + return input_lengths + + +@add_start_docstrings( + "Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).", + WAV2VEC2_START_DOCSTRING, +) +class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel): + module_class = FlaxWav2Vec2ForCTCModule + + +FLAX_WAV2VEC2_FOR_CTC_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> import jax.numpy as jnp + >>> from transformers import AutoProcessor, FlaxWav2Vec2ForCTC + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-960h-lv60") + >>> model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60") + + + >>> def map_to_array(batch): + ... speech, _ = sf.read(batch["file"]) + ... batch["speech"] = speech + ... return batch + + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = processor( + ... ds["speech"][0], sampling_rate=16_000, return_tensors="np" + ... 
).input_values # Batch size 1 + >>> logits = model(input_values).logits + >>> predicted_ids = jnp.argmax(logits, axis=-1) + + >>> transcription = processor.decode(predicted_ids[0]) + >>> # should give: "A MAN SAID TO THE UNIVERSE SIR I EXIST" + ``` +""" + +overwrite_call_docstring( + FlaxWav2Vec2ForCTC, + WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_CTC_DOCSTRING, +) +append_replace_return_docstrings(FlaxWav2Vec2ForCTC, output_type=FlaxCausalLMOutput, config_class=Wav2Vec2Config) + + +class FlaxWav2Vec2ForPreTrainingModule(nn.Module): + config: Wav2Vec2Config + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype) + self.dropout_features = nn.Dropout(self.config.feat_quantizer_dropout) + + self.quantizer = FlaxWav2Vec2GumbelVectorQuantizer(self.config, dtype=self.dtype) + self.project_q = nn.Dense( + self.config.proj_codevector_dim, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.project_hid = nn.Dense( + self.config.proj_codevector_dim, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + gumbel_temperature: int = 1, + deterministic: bool = True, + output_attentions=None, + output_hidden_states=None, + freeze_feature_encoder=False, + return_dict=None, + ): + r""" + Returns: + + Example: + + ```python + + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + mask_time_indices=mask_time_indices, + deterministic=deterministic, + freeze_feature_encoder=freeze_feature_encoder, + return_dict=return_dict, + ) + + # project all transformed features (including masked) to final vq dim + transformer_features = self.project_hid(outputs[0]) + + # quantize all (unmasked) extracted features and project to final vq dim + extract_features = self.dropout_features(outputs[1], deterministic=deterministic) + quantized_features, codevector_perplexity = self.quantizer( + extract_features, mask_time_indices, deterministic=deterministic, temperature=gumbel_temperature + ) + quantized_features = self.project_q(quantized_features) + + if not return_dict: + return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + + return FlaxWav2Vec2ForPreTrainingOutput( + projected_states=transformer_features, + projected_quantized_states=quantized_features, + codevector_perplexity=codevector_perplexity, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None + ): + """ + Computes the output length of the convolutional layers + """ + + add_adapter = self.config.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + if add_adapter: + for _ in range(self.config.num_adapter_layers): + input_lengths = 
_conv_out_length(input_lengths, 1, self.config.adapter_stride) + + return input_lengths + + +@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV2VEC2_START_DOCSTRING) +class FlaxWav2Vec2ForPreTraining(FlaxWav2Vec2PreTrainedModel): + module_class = FlaxWav2Vec2ForPreTrainingModule + + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) + # overwrite since has `gumbel_temperature` input + def __call__( + self, + input_values, + attention_mask=None, + mask_time_indices=None, + gumbel_temperature: int = 1, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + gumbel_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + freeze_feature_encoder: bool = False, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + batch_size, sequence_length = input_values.shape + + if attention_mask is None: + attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + if gumbel_rng is not None: + rngs["gumbel"] = gumbel_rng + + inputs = {"params": params or self.params} + + return self.module.apply( + inputs, + jnp.array(input_values, dtype="f4"), + jnp.array(attention_mask, dtype="i4"), + mask_time_indices, + gumbel_temperature, + not train, + output_attentions, + output_hidden_states, + freeze_feature_encoder, + return_dict, + rngs=rngs, + ) + + +FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING = """ + Returns: + + Example: + + ```python + >>> import optax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from transformers import AutoFeatureExtractor, FlaxWav2Vec2ForPreTraining + >>> from transformers.models.wav2vec2.modeling_flax_wav2vec2 import _compute_mask_indices + >>> from datasets import load_dataset + >>> import soundfile as sf + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-large-lv60") + >>> model = FlaxWav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-large-lv60") + + + >>> def map_to_array(batch): + ... speech, _ = sf.read(batch["file"]) + ... batch["speech"] = speech + ... 
return batch + + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = feature_extractor(ds["speech"][0], return_tensors="np").input_values # Batch size 1 + + >>> # compute masked indices + >>> batch_size, raw_sequence_length = input_values.shape + >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) + >>> mask_time_indices = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.2, mask_length=2) + + >>> outputs = model(input_values, mask_time_indices=mask_time_indices) + + >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states) + >>> cosine_sim = optax.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states) + + >>> # show that cosine similarity is much higher than random + >>> assert np.asarray(cosine_sim)[mask_time_indices].mean() > 0.5 + ``` +""" + +overwrite_call_docstring( + FlaxWav2Vec2ForPreTraining, + WAV2VEC2_INPUTS_DOCSTRING + FLAX_WAV2VEC2_FOR_PRETRAINING_DOCSTRING, +) +append_replace_return_docstrings( + FlaxWav2Vec2ForPreTraining, output_type=FlaxWav2Vec2ForPreTrainingOutput, config_class=Wav2Vec2Config +) + + +__all__ = ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"] diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/docs/transformers/build/lib/transformers/models/wav2vec2/modeling_tf_wav2vec2.py new file mode 100644 index 0000000000000000000000000000000000000000..c385c192a987d5bd7425b6cd7c9ac6245c336c11 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/wav2vec2/modeling_tf_wav2vec2.py @@ -0,0 +1,1858 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TensorFlow Wav2Vec2 model.""" + +from __future__ import annotations + +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput, TFSequenceClassifierOutput +from ...modeling_tf_utils import ( + TFPreTrainedModel, + get_initializer, + keras, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_wav2vec2 import Wav2Vec2Config + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 2 + +_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h" +_CONFIG_FOR_DOC = "Wav2Vec2Config" + + +LARGE_NEGATIVE = -1e8 + + +@dataclass +class TFWav2Vec2BaseModelOutput(ModelOutput): + """ + Output type of [`TFWav2Vec2BaseModelOutput`], with potential hidden states and attentions. 
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        extract_features (`tf.Tensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
+            Sequence of extracted feature vectors of the last convolutional layer of the model.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: Optional[tf.Tensor] = None
+    extract_features: Optional[tf.Tensor] = None
+    hidden_states: Tuple[tf.Tensor] | None = None
+    attentions: Tuple[tf.Tensor] | None = None
+
+
+def _sample_without_replacement(distribution, num_samples):
+    """
+    Categorical sampling without replacement is currently not implemented. The Gumbel-max trick will do for now - see
+    https://github.com/tensorflow/tensorflow/issues/9260 for more info
+    """
+    z = -tf.math.log(tf.random.uniform(shape_list(distribution), 0, 1))
+    _, indices = tf.nn.top_k(distribution + z, num_samples)
+    return indices
+
+
+def _scatter_values_on_batch_indices(values, batch_indices, output_shape):
+    """
+    Scatter function as in PyTorch with indices in format (batch_dim, indices)
+    """
+    indices_shape = shape_list(batch_indices)
+    # broadcast batch dim to indices_shape
+    broad_casted_batch_dims = tf.reshape(
+        tf.broadcast_to(tf.expand_dims(tf.range(indices_shape[0]), axis=-1), indices_shape), [1, -1]
+    )
+    # transform batch_indices to pair_indices
+    pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
+    # scatter values to pair indices
+    return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), output_shape)
+
+
+def _compute_mask_indices(
+    shape: Tuple[int, int],
+    mask_prob: float,
+    mask_length: int,
+    min_masks: int = 0,
+) -> tf.Tensor:
+    """
+    Computes random mask spans for a given shape.
+
+    Args:
+        shape: the shape for which to compute masks.
+            Should be of size 2, where the first element is the batch size and the second is the number of timesteps.
+        mask_prob:
+            probability for each token to be chosen as the start of a span to be masked. This will be multiplied by
+            the number of timesteps divided by the length of the mask span, so that approximately this percentage of
+            all elements is masked. However, due to overlaps, the actual number will be smaller.
+        mask_length: size of each masked span
+        min_masks: minimum number of masked spans
+
+    Adapted from [fairseq's
+    data_utils.py](https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376).
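+
+    For intuition, a hypothetical call (the spans are placed randomly, so only the shape is deterministic):
+
+    ```python
+    >>> mask = _compute_mask_indices(shape=(2, 100), mask_prob=0.2, mask_length=2)
+    >>> # `mask` has shape (2, 100); each row holds roughly 0.2 * 100 / 2 = 10 spans of two consecutive ones
+    ```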
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + tf.debugging.assert_less( + mask_length, + sequence_length, + message=( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and" + f" `sequence_length`: {sequence_length}`" + ), + ) + + # compute number of masked spans in batch + num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,)) + num_masked_spans = tf.maximum(num_masked_spans, min_masks) + num_masked_spans = tf.cast(num_masked_spans, tf.int32) + + # make sure num masked indices <= sequence_length + num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans) + num_masked_spans = tf.squeeze(num_masked_spans) + + # SpecAugment mask to fill + spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32) + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = tf.ones((batch_size, sequence_length - (mask_length - 1))) + + # get random indices to mask + spec_aug_mask_idxs = _sample_without_replacement(uniform_dist, num_masked_spans) + + # expand masked indices to masked spans + spec_aug_mask_idxs = tf.expand_dims(spec_aug_mask_idxs, -1) + spec_aug_mask_idxs = tf.tile(spec_aug_mask_idxs, (1, 1, mask_length)) + spec_aug_mask_idxs = tf.reshape(spec_aug_mask_idxs, (batch_size, num_masked_spans * mask_length)) + + offsets = tf.range(mask_length)[tf.newaxis, tf.newaxis, :] + offsets = tf.tile(offsets, (batch_size, num_masked_spans, 1)) + offsets = tf.reshape(offsets, (batch_size, num_masked_spans * mask_length)) + + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # scatter indices to mask + spec_aug_mask = _scatter_values_on_batch_indices( + tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask) + ) + + return spec_aug_mask + + +# Copied from transformers.models.bart.modeling_tf_bart._expand_mask +def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + src_len = shape_list(mask)[1] + tgt_len = tgt_len if tgt_len is not None else src_len + one_cst = tf.constant(1.0) + mask = tf.cast(mask, dtype=one_cst.dtype) + expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1)) + + return (one_cst - expanded_mask) * LARGE_NEGATIVE + + +class TFWav2Vec2GroupNorm(keras.layers.Layer): + """ + From tensorflow-addons https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization + """ + + def __init__( + self, + groups: int = 32, + axis: int = -1, + epsilon: float = 1e-3, + center: bool = True, + scale: bool = True, + beta_initializer: keras.initializers.Initializer = "zeros", + gamma_initializer: keras.initializers.Initializer = "ones", + beta_regularizer: keras.regularizers.Regularizer = None, + gamma_regularizer: keras.regularizers.Regularizer = None, + beta_constraint: keras.constraints.Constraint = None, + gamma_constraint: keras.constraints.Constraint = None, + **kwargs, + ): + super().__init__(**kwargs) + self.supports_masking = True + self.groups = groups + self.axis = axis + self.epsilon = epsilon + self.center = center + self.scale = scale + self.beta_initializer = keras.initializers.get(beta_initializer) + self.gamma_initializer = keras.initializers.get(gamma_initializer) + self.beta_regularizer = keras.regularizers.get(beta_regularizer) + self.gamma_regularizer = keras.regularizers.get(gamma_regularizer) + self.beta_constraint = keras.constraints.get(beta_constraint) + self.gamma_constraint = keras.constraints.get(gamma_constraint) + self._check_axis() + + def build(self, input_shape): + self._check_if_input_shape_is_none(input_shape) + self._set_number_of_groups_for_instance_norm(input_shape) + self._check_size_of_dimensions(input_shape) + self._create_input_spec(input_shape) + + self._add_gamma_weight(input_shape) + self._add_beta_weight(input_shape) + self.built = True + super().build(input_shape) + + def call(self, inputs): + input_shape = keras.backend.int_shape(inputs) + tensor_input_shape = tf.shape(inputs) + + reshaped_inputs, group_shape = self._reshape_into_groups(inputs, input_shape, tensor_input_shape) + + normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape) + + is_instance_norm = (input_shape[self.axis] // self.groups) == 1 + if not is_instance_norm: + outputs = tf.reshape(normalized_inputs, tensor_input_shape) + else: + outputs = normalized_inputs + + return outputs + + def get_config(self): + config = { + "groups": self.groups, + "axis": self.axis, + "epsilon": self.epsilon, + "center": self.center, + "scale": self.scale, + "beta_initializer": keras.initializers.serialize(self.beta_initializer), + "gamma_initializer": keras.initializers.serialize(self.gamma_initializer), + "beta_regularizer": keras.regularizers.serialize(self.beta_regularizer), + "gamma_regularizer": keras.regularizers.serialize(self.gamma_regularizer), + "beta_constraint": keras.constraints.serialize(self.beta_constraint), + "gamma_constraint": keras.constraints.serialize(self.gamma_constraint), + } + base_config = super().get_config() + return {**base_config, **config} + + def compute_output_shape(self, input_shape): + return input_shape + + def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape): + group_shape = [tensor_input_shape[i] for i in range(len(input_shape))] + is_instance_norm = (input_shape[self.axis] // self.groups) == 1 + if not is_instance_norm: + group_shape[self.axis] = input_shape[self.axis] // self.groups + group_shape.insert(self.axis, self.groups) + group_shape = 
tf.stack(group_shape) + reshaped_inputs = tf.reshape(inputs, group_shape) + return reshaped_inputs, group_shape + else: + return inputs, group_shape + + def _apply_normalization(self, reshaped_inputs, input_shape): + group_shape = keras.backend.int_shape(reshaped_inputs) + group_reduction_axes = list(range(1, len(group_shape))) + is_instance_norm = (input_shape[self.axis] // self.groups) == 1 + if not is_instance_norm: + axis = -2 if self.axis == -1 else self.axis - 1 + else: + axis = -1 if self.axis == -1 else self.axis - 1 + group_reduction_axes.pop(axis) + + mean, variance = tf.nn.moments(reshaped_inputs, group_reduction_axes, keepdims=True) + + gamma, beta = self._get_reshaped_weights(input_shape) + normalized_inputs = tf.nn.batch_normalization( + reshaped_inputs, + mean=mean, + variance=variance, + scale=gamma, + offset=beta, + variance_epsilon=self.epsilon, + ) + return normalized_inputs + + def _get_reshaped_weights(self, input_shape): + broadcast_shape = self._create_broadcast_shape(input_shape) + gamma = None + beta = None + if self.scale: + gamma = tf.reshape(self.gamma, broadcast_shape) + + if self.center: + beta = tf.reshape(self.beta, broadcast_shape) + return gamma, beta + + def _check_if_input_shape_is_none(self, input_shape): + dim = input_shape[self.axis] + if dim is None: + raise ValueError( + "Axis " + + str(self.axis) + + " of input tensor should have a defined dimension but the layer received an input with shape " + + str(input_shape) + + "." + ) + + def _set_number_of_groups_for_instance_norm(self, input_shape): + dim = input_shape[self.axis] + + if self.groups == -1: + self.groups = dim + + def _check_size_of_dimensions(self, input_shape): + dim = input_shape[self.axis] + if dim < self.groups: + raise ValueError( + "Number of groups (" + + str(self.groups) + + ") cannot be more than the number of channels (" + + str(dim) + + ")." + ) + + if dim % self.groups != 0: + raise ValueError( + "Number of groups (" + + str(self.groups) + + ") must be a multiple of the number of channels (" + + str(dim) + + ")." + ) + + def _check_axis(self): + if self.axis == 0: + raise ValueError( + "You are trying to normalize your batch axis. 
Do you want to use tf.layer.batch_normalization instead" + ) + + def _create_input_spec(self, input_shape): + dim = input_shape[self.axis] + self.input_spec = keras.layers.InputSpec(ndim=len(input_shape), axes={self.axis: dim}) + + def _add_gamma_weight(self, input_shape): + dim = input_shape[self.axis] + shape = (dim,) + + if self.scale: + self.gamma = self.add_weight( + shape=shape, + name="gamma", + initializer=self.gamma_initializer, + regularizer=self.gamma_regularizer, + constraint=self.gamma_constraint, + ) + else: + self.gamma = None + + def _add_beta_weight(self, input_shape): + dim = input_shape[self.axis] + shape = (dim,) + + if self.center: + self.beta = self.add_weight( + shape=shape, + name="beta", + initializer=self.beta_initializer, + regularizer=self.beta_regularizer, + constraint=self.beta_constraint, + ) + else: + self.beta = None + + def _create_broadcast_shape(self, input_shape): + broadcast_shape = [1] * len(input_shape) + is_instance_norm = (input_shape[self.axis] // self.groups) == 1 + if not is_instance_norm: + broadcast_shape[self.axis] = input_shape[self.axis] // self.groups + broadcast_shape.insert(self.axis, self.groups) + else: + broadcast_shape[self.axis] = self.groups + return broadcast_shape + + +class TFWav2Vec2WeightNormConv1D(keras.layers.Conv1D): + """Adapted from https://www.tensorflow.org/probability/api_docs/python/tfp/layers/weight_norm/WeightNorm""" + + def __init__(self, filters, kernel_size, groups, explicit_padding, **kwargs): + super().__init__( + filters=filters, + kernel_size=kernel_size, + groups=groups, + padding="valid", + use_bias=True, + bias_initializer="he_normal", + **kwargs, + ) + self.explicit_padding = explicit_padding + self.filter_axis = 2 + self.kernel_norm_axes = tf.constant([0, 1]) + + def _init_norm(self): + """Set the norm of the weight vector.""" + kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.weight_v), axis=self.kernel_norm_axes)) + self.weight_g.assign(kernel_norm[:, tf.newaxis, tf.newaxis]) + + def _normalize_kernel(self): + """Generate normalized weights.""" + kernel = tf.nn.l2_normalize(self.weight_v, axis=self.kernel_norm_axes) * tf.transpose(self.weight_g) + self.kernel = tf.transpose(kernel) + + def build(self, input_shape): + if not self.built: + super().build(input_shape) + + self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True) + self.weight_v = self.kernel + + self.weight_g = self.add_weight( + name="weight_g", + shape=(int(self.weight_v.shape[self.filter_axis]), 1, 1), + initializer="ones", + dtype=self.weight_v.dtype, + trainable=True, + ) + self._init_norm() + self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros", trainable=True) + + def call(self, inputs): + # TODO Matt: Assigning to attributes in call() is deeply sinful in TensorFlow, as it should be idempotent. + # This whole layer should be replaced by a layer that doesn't inherit from Conv1D, but instead calls + # a functional 1d convolution with normalized weights that it generates (but does not store!) 
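+        # weight-norm reparameterization: kernel = weight_g * weight_v / ||weight_v||
+        # (Salimans & Kingma, 2016), with the norm taken over `kernel_norm_axes`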
+ self._normalize_kernel() + + padded_inputs = tf.pad(inputs, ((0, 0), (self.explicit_padding, self.explicit_padding), (0, 0))) + output = super().call(padded_inputs) + + return output + + +class TFWav2Vec2NoLayerNormConvLayer(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = keras.layers.Conv1D( + filters=self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + strides=config.conv_stride[layer_id], + use_bias=config.conv_bias, + name="conv", + ) + self.activation = get_tf_activation(config.feat_extract_activation) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv", None) is not None: + with tf.name_scope(self.conv.name): + self.conv.build([None, None, self.in_conv_dim]) + + +class TFWav2Vec2LayerNormConvLayer(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = keras.layers.Conv1D( + filters=self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + strides=config.conv_stride[layer_id], + use_bias=config.conv_bias, + name="conv", + ) + self.layer_norm = keras.layers.LayerNormalization(name="layer_norm", epsilon=config.layer_norm_eps) + self.activation = get_tf_activation(config.feat_extract_activation) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv", None) is not None: + with tf.name_scope(self.conv.name): + self.conv.build([None, None, self.in_conv_dim]) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.out_conv_dim]) + + +class TFWav2Vec2GroupNormConvLayer(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, layer_id: int = 0, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = keras.layers.Conv1D( + filters=self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + strides=config.conv_stride[layer_id], + use_bias=config.conv_bias, + name="conv", + ) + self.activation = get_tf_activation(config.feat_extract_activation) + self.layer_norm = TFWav2Vec2GroupNorm( + groups=self.out_conv_dim, epsilon=config.layer_norm_eps, name="layer_norm" + ) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv", None) is not None: + with tf.name_scope(self.conv.name): + self.conv.build([None, None, self.in_conv_dim]) + if getattr(self, 
"layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.out_conv_dim]) + + +class TFWav2Vec2PositionalConvEmbedding(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.conv = TFWav2Vec2WeightNormConv1D( + filters=config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + groups=config.num_conv_pos_embedding_groups, + explicit_padding=config.num_conv_pos_embeddings // 2, + name="conv", + ) + self.padding = TFWav2Vec2SamePadLayer(config.num_conv_pos_embeddings) + self.activation = get_tf_activation(config.feat_extract_activation) + self.config = config + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv", None) is not None: + with tf.name_scope(self.conv.name): + self.conv.build([None, None, self.config.hidden_size]) + + +class TFWav2Vec2SamePadLayer(keras.layers.Layer): + def __init__(self, num_conv_pos_embeddings, **kwargs): + super().__init__(**kwargs) + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def call(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, : -self.num_pad_remove, :] + return hidden_states + + +class TFWav2Vec2FeatureEncoder(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, **kwargs: Any) -> None: + super().__init__(**kwargs) + + if config.feat_extract_norm == "group": + conv_layers = [TFWav2Vec2GroupNormConvLayer(config, layer_id=0, name=f"conv_layers.{0}")] + [ + TFWav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1, name=f"conv_layers.{i + 1}") + for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [ + TFWav2Vec2LayerNormConvLayer(config, layer_id=i, name=f"conv_layers.{i}") + for i in range(config.num_feat_extract_layers) + ] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = conv_layers + + def call(self, input_values): + hidden_states = tf.expand_dims(input_values, -1) + for conv_layer in self.conv_layers: + hidden_states = conv_layer(hidden_states) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "conv_layers", None) is not None: + for conv_layer in self.conv_layers: + with tf.name_scope(conv_layer.name): + conv_layer.build(None) + + +class TFWav2Vec2FeatureExtractor(TFWav2Vec2FeatureEncoder): + def __init__(self, config, **kwargs): + super().__init__(config, **kwargs) + warnings.warn( + f"The class `{self.__class__.__name__}` has been depreciated " + "and will be removed in Transformers v5. 
" + f"Use `{self.__class__.__bases__[0].__name__}` instead.", + FutureWarning, + ) + + +class TFWav2Vec2FeatureProjection(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, **kwargs): + super().__init__(**kwargs) + + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.projection = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="projection", + ) + self.dropout = keras.layers.Dropout(rate=config.feat_proj_dropout) + self.config = config + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + return hidden_states, norm_hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.conv_dim[-1]]) + if getattr(self, "projection", None) is not None: + with tf.name_scope(self.projection.name): + self.projection.build([None, None, self.config.conv_dim[-1]]) + + +# Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFWav2Vec2 +class TFWav2Vec2Attention(keras.layers.Layer): + """Multi-headed attention from "Attention Is All You Need""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = keras.layers.Dropout(dropout) + self.head_dim = embed_dim // num_heads + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj") + self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj") + self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj") + self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj") + + def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int): + return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3)) + + def call( + self, + hidden_states: tf.Tensor, + key_value_states: tf.Tensor | None = None, + past_key_value: Tuple[Tuple[tf.Tensor]] | None = None, + attention_mask: tf.Tensor | None = None, + layer_head_mask: tf.Tensor | None = None, + training: Optional[bool] = False, + ) -> Tuple[tf.Tensor, tf.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, embed_dim = shape_list(hidden_states) + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = tf.concat([past_key_value[0], key_states], axis=2) + value_states = tf.concat([past_key_value[1], value_states], axis=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. 
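A quick aside on the caching branches above: once `_shape` has been applied, axis 2 is the time axis, so reusing a decoder cache is just a concatenation along that axis. A minimal sketch with made-up sizes (nothing here comes from the file itself):

```python
import tensorflow as tf

# Illustrative sizes only: (bsz, num_heads, past_len, head_dim) after `_shape`.
past_key = tf.zeros((2, 4, 3, 8))
# Keys projected for the current single step.
new_key = tf.zeros((2, 4, 1, 8))
# The "reuse k, v, self_attention" branch concatenates along the time axis (axis=2).
key_states = tf.concat([past_key, new_key], axis=2)
print(key_states.shape)  # (2, 4, 4, 8)
```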
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape) + key_states = tf.reshape(key_states, proj_shape) + value_states = tf.reshape(value_states, proj_shape) + + src_len = shape_list(key_states)[1] + attn_weights = tf.matmul(query_states, key_states, transpose_b=True) + + tf.debugging.assert_equal( + shape_list(attn_weights), + [bsz * self.num_heads, tgt_len, src_len], + message=( + f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" + f" {shape_list(attn_weights)}" + ), + ) + + if attention_mask is not None: + tf.debugging.assert_equal( + shape_list(attention_mask), + [bsz, 1, tgt_len, src_len], + message=( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" + f" {shape_list(attention_mask)}" + ), + ) + + attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype) + attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_weights = stable_softmax(attn_weights, axis=-1) + + if layer_head_mask is not None: + tf.debugging.assert_equal( + shape_list(layer_head_mask), + [self.num_heads], + message=( + f"Head mask for a single layer should be of size {(self.num_heads)}, but is" + f" {shape_list(layer_head_mask)}" + ), + ) + + attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape( + attn_weights, (bsz, self.num_heads, tgt_len, src_len) + ) + attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len)) + + attn_probs = self.dropout(attn_weights, training=training) + attn_output = tf.matmul(attn_probs, value_states) + + tf.debugging.assert_equal( + shape_list(attn_output), + [bsz * self.num_heads, tgt_len, self.head_dim], + message=( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {shape_list(attn_output)}" + ), + ) + + attn_output = tf.transpose( + tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3) + ) + attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim)) + + attn_output = self.out_proj(attn_output) + attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + + return attn_output, attn_weights, past_key_value + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "k_proj", None) is not None: + with tf.name_scope(self.k_proj.name): + self.k_proj.build([None, None, self.embed_dim]) + if getattr(self, "q_proj", None) is not None: + with tf.name_scope(self.q_proj.name): + self.q_proj.build([None, None, self.embed_dim]) + if getattr(self, "v_proj", None) is not None: + with tf.name_scope(self.v_proj.name): + self.v_proj.build([None, None, self.embed_dim]) + if getattr(self, "out_proj", None) is not None: + with tf.name_scope(self.out_proj.name): + self.out_proj.build([None, None, self.embed_dim]) + + +class TFWav2Vec2FeedForward(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, **kwargs): + super().__init__(**kwargs) + + self.intermediate_dropout = keras.layers.Dropout(config.activation_dropout) + + 
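The `attn_weights + attention_mask` line above relies on the usual additive-mask convention: positions to be ignored carry a large negative value, so their probability is driven to zero by the softmax. A toy sketch (the values are arbitrary; the real mask is produced by `_expand_mask`):

```python
import tensorflow as tf

scores = tf.constant([[2.0, 1.0, 0.5]])          # raw attention scores for one query
additive_mask = tf.constant([[0.0, 0.0, -1e9]])  # third key position is padding
probs = tf.nn.softmax(scores + additive_mask, axis=-1)
print(probs.numpy().round(2))  # ~[[0.73, 0.27, 0.0]]
```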
self.intermediate_dense = keras.layers.Dense( + units=config.intermediate_size, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="intermediate_dense", + ) + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + + self.output_dense = keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + bias_initializer="zeros", + name="output_dense", + ) + self.output_dropout = keras.layers.Dropout(config.hidden_dropout) + self.config = config + + def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states, training=training) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states, training=training) + return hidden_states + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "intermediate_dense", None) is not None: + with tf.name_scope(self.intermediate_dense.name): + self.intermediate_dense.build([None, None, self.config.hidden_size]) + if getattr(self, "output_dense", None) is not None: + with tf.name_scope(self.output_dense.name): + self.output_dense.build([None, None, self.config.intermediate_size]) + + +class TFWav2Vec2EncoderLayer(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, **kwargs): + super().__init__(**kwargs) + self.attention = TFWav2Vec2Attention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + name="attention", + ) + self.dropout = keras.layers.Dropout(config.hidden_dropout) + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.feed_forward = TFWav2Vec2FeedForward(config, name="feed_forward") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + output_attentions: Optional[bool] = False, + training: bool = False, + ) -> Tuple[tf.Tensor]: + attn_residual = hidden_states + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, training=training + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + if getattr(self, "feed_forward", None) is not None: + with tf.name_scope(self.feed_forward.name): + self.feed_forward.build(None) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.config.hidden_size]) + + +class 
TFWav2Vec2EncoderLayerStableLayerNorm(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, **kwargs): + super().__init__(**kwargs) + self.attention = TFWav2Vec2Attention( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + name="attention", + ) + self.dropout = keras.layers.Dropout(config.hidden_dropout) + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.feed_forward = TFWav2Vec2FeedForward(config, name="feed_forward") + self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm") + self.config = config + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + output_attentions: Optional[bool] = False, + training: bool = False, + ) -> Tuple[tf.Tensor]: + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, training=training + ) + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "attention", None) is not None: + with tf.name_scope(self.attention.name): + self.attention.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + if getattr(self, "feed_forward", None) is not None: + with tf.name_scope(self.feed_forward.name): + self.feed_forward.build(None) + if getattr(self, "final_layer_norm", None) is not None: + with tf.name_scope(self.final_layer_norm.name): + self.final_layer_norm.build([None, None, self.config.hidden_size]) + + +class TFWav2Vec2Encoder(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.pos_conv_embed = TFWav2Vec2PositionalConvEmbedding(config, name="pos_conv_embed") + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.dropout = keras.layers.Dropout(config.hidden_dropout) + self.layer = [TFWav2Vec2EncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + hidden_states = hidden_states * tf.expand_dims(attention_mask, -1) + attention_mask = _expand_mask(attention_mask) + else: + attention_mask = None + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states, training=training) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = 
all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + if training and (dropout_probability < self.config.layerdrop): # skip the layer + continue + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "pos_conv_embed", None) is not None: + with tf.name_scope(self.pos_conv_embed.name): + self.pos_conv_embed.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +class TFWav2Vec2EncoderStableLayerNorm(keras.layers.Layer): + def __init__(self, config: Wav2Vec2Config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.pos_conv_embed = TFWav2Vec2PositionalConvEmbedding(config, name="pos_conv_embed") + self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.dropout = keras.layers.Dropout(config.hidden_dropout) + self.layer = [ + TFWav2Vec2EncoderLayerStableLayerNorm(config, name=f"layers.{i}") for i in range(config.num_hidden_layers) + ] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor | None = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + training: Optional[bool] = False, + ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + hidden_states = hidden_states * tf.expand_dims(attention_mask, -1) + attention_mask = _expand_mask(attention_mask) + else: + attention_mask = None + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.dropout(hidden_states, training=training) + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + if training and (dropout_probability < self.config.layerdrop): # skip the layer + continue + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if 
not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return TFBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "pos_conv_embed", None) is not None: + with tf.name_scope(self.pos_conv_embed.name): + self.pos_conv_embed.build(None) + if getattr(self, "layer_norm", None) is not None: + with tf.name_scope(self.layer_norm.name): + self.layer_norm.build([None, None, self.config.hidden_size]) + if getattr(self, "layer", None) is not None: + for layer in self.layer: + with tf.name_scope(layer.name): + layer.build(None) + + +@keras_serializable +class TFWav2Vec2MainLayer(keras.layers.Layer): + config_class = Wav2Vec2Config + + def __init__(self, config: Wav2Vec2Config, **kwargs): + super().__init__(**kwargs) + self.config = config + self.feature_extractor = TFWav2Vec2FeatureEncoder(config, name="feature_extractor") + self.feature_projection = TFWav2Vec2FeatureProjection(config, name="feature_projection") + + if config.do_stable_layer_norm: + self.encoder = TFWav2Vec2EncoderStableLayerNorm(config, name="encoder") + else: + self.encoder = TFWav2Vec2Encoder(config, name="encoder") + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0: + self.masked_spec_embed = self.add_weight( + shape=(self.config.hidden_size,), initializer="uniform", trainable=True, name="masked_spec_embed" + ) + if getattr(self, "feature_extractor", None) is not None: + with tf.name_scope(self.feature_extractor.name): + self.feature_extractor.build(None) + if getattr(self, "feature_projection", None) is not None: + with tf.name_scope(self.feature_projection.name): + self.feature_projection.build(None) + if getattr(self, "encoder", None) is not None: + with tf.name_scope(self.encoder.name): + self.encoder.build(None) + + def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _mask_hidden_states(self, hidden_states: tf.Tensor, mask_time_indices: tf.Tensor | None = None): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). 
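To make the `_conv_out_length` formula defined just above concrete: with the default wav2vec 2.0 feature extractor (kernel sizes `(10, 3, 3, 3, 3, 2, 2)` and strides `(5, 2, 2, 2, 2, 2, 2)`, the library defaults), one second of 16 kHz audio collapses to 49 frames, roughly one frame every 20 ms. A small sketch of that arithmetic:

```python
def conv_out_length(input_length: int, kernel_size: int, stride: int) -> int:
    # Same formula as `_conv_out_length` above.
    return (input_length - kernel_size) // stride + 1

length = 16000  # one second of 16 kHz audio
for kernel_size, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    length = conv_out_length(length, kernel_size, stride)
print(length)  # 49
```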
+ """ + batch_size, sequence_length, hidden_size = shape_list(hidden_states) + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states = tf.where( + tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool), + self.masked_spec_embed[tf.newaxis, tf.newaxis, :], + hidden_states, + ) + + elif self.config.mask_time_prob > 0: + # generate indices & apply SpecAugment along time axis + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + min_masks=2, + ) + hidden_states = tf.where( + tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool), + self.masked_spec_embed[tf.newaxis, tf.newaxis, :], + hidden_states, + ) + + # apply SpecAugment along feature axis + if self.config.mask_feature_prob > 0: + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + ) + hidden_states = tf.where(mask_feature_indices[:, tf.newaxis, :], hidden_states, 0) + + return hidden_states + + @unpack_inputs + def call( + self, + input_values: tf.Tensor, + attention_mask: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + **kwargs: Any, + ): + extract_features = self.feature_extractor(tf.cast(input_values, tf.float32), training=training) + # extract_features = tf.transpose(extract_features, perm=(0, 2, 1)) + + if attention_mask is not None: + # compute real output lengths according to convolution formula + output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1)) + + attention_mask = tf.sequence_mask( + output_lengths, maxlen=shape_list(extract_features)[1], dtype=extract_features.dtype + ) + + hidden_states, extract_features = self.feature_projection(extract_features, training=training) + + mask_time_indices = kwargs.get("mask_time_indices", None) + if training: + hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return TFWav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class TFWav2Vec2PreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = Wav2Vec2Config + base_model_prefix = "wav2vec2" + main_input_name = "input_values" + + @property + def input_signature(self): + return { + "input_values": tf.TensorSpec((None, None), tf.float32, name="input_values"), + "attention_mask": tf.TensorSpec((None, None), tf.float32, name="attention_mask"), + } + + @property + def dummy_inputs(self): + return { + "input_values": tf.random.uniform(shape=(1, 500), dtype=tf.float32), + "attention_mask": tf.ones(shape=(1, 500), dtype=tf.float32), + } + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + logger.warning( + f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish " + "to train/fine-tune this model, you need a GPU or a TPU" + ) + + def _get_feat_extract_output_lengths(self, input_lengths, add_adapter=None): + """ + Computes the output length of the convolutional layers + """ + add_adapter = self.config.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride): + return tf.math.floordiv(input_length - kernel_size, stride) + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + if add_adapter: + for _ in range(self.config.num_adapter_layers): + input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) + return input_lengths + + def _get_feature_vector_attention_mask( + self, feature_vector_length: int, attention_mask: tf.Tensor, add_adapter=None + ): + non_padded_lengths = tf.math.cumsum(attention_mask, axis=-1)[:, -1] + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter) + output_lengths = tf.cast(output_lengths, tf.int32) + batch_size = tf.shape(attention_mask)[0] + # check device here + attention_mask = tf.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, name="attention_mask" + ) # these two operations makes sure that all values before the output lengths idxs are attended to + ## check device + attention_mask = tf.tensor_scatter_nd_update( + attention_mask, + indices=tf.stack([tf.range(batch_size), output_lengths - 1], axis=1), + updates=tf.ones([batch_size], dtype=attention_mask.dtype), + ) + attention_mask = tf.reverse(attention_mask, axis=[-1]) + attention_mask = tf.cumsum(attention_mask, axis=-1) + attention_mask = tf.reverse(attention_mask, axis=[-1]) + attention_mask = tf.cast(attention_mask, tf.bool) + return attention_mask + + +WAV2VEC2_START_DOCSTRING = r""" + + This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it + as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and + behavior. + + + + TensorFlow models and layers in `transformers` accept two formats as input: + + - having all inputs as keyword arguments (like PyTorch models), or + - having all inputs as a list, tuple or dict in the first positional argument. + + The reason the second format is supported is that Keras methods prefer this format when passing inputs to models + and layers. 
Because of this support, when using methods like `model.fit()` things should "just work" for you - just + pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second + format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with + the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first + positional argument: + + - a single Tensor with `input_values` only and nothing else: `model(input_values)` + - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring: + `model([input_values, attention_mask])` or `model([input_values, attention_mask, token_type_ids])` + - a dictionary with one or several input Tensors associated to the input names given in the docstring: + `model({"input_values": input_values, "token_type_ids": token_type_ids})` + + Note that when creating models and layers with + [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry + about any of this, as you can just pass inputs like you would to any other Python function! + + + + Args: + config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +WAV2VEC2_INPUTS_DOCSTRING = r""" + Args: + input_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and + [`PreTrainedTokenizer.encode`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_values` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_values` indices into associated vectors + than the model's internal embedding lookup matrix. 
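As a usage illustration of the three input formats listed above, the equivalent calls look like this; the tiny randomly initialised config is arbitrary, chosen only so the snippet runs quickly without downloading weights:

```python
import tensorflow as tf
from transformers import TFWav2Vec2Model, Wav2Vec2Config

config = Wav2Vec2Config(hidden_size=32, num_hidden_layers=1, num_attention_heads=2, intermediate_size=64)
model = TFWav2Vec2Model(config)
input_values = tf.random.normal((1, 8000))
attention_mask = tf.ones((1, 8000))

out = model(input_values)                                                      # single tensor
out = model([input_values, attention_mask])                                    # positional list
out = model({"input_values": input_values, "attention_mask": attention_mask})  # dict keyed by input names
```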
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
+            config will be used instead.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
+            used instead.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
+            eager mode, in graph mode the value will always be set to True.
+        training (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the model in training mode (some modules like dropout modules have different
+            behaviors between training and evaluation).
+"""
+
+
+@add_start_docstrings(
+    "The bare TFWav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
+    WAV2VEC2_START_DOCSTRING,
+)
+class TFWav2Vec2Model(TFWav2Vec2PreTrainedModel):
+    def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs):
+        super().__init__(config, *inputs, **kwargs)
+        self.config = config
+        self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2")
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @unpack_inputs
+    def call(
+        self,
+        input_values: tf.Tensor,
+        attention_mask: tf.Tensor | None = None,
+        token_type_ids: tf.Tensor | None = None,
+        position_ids: tf.Tensor | None = None,
+        head_mask: tf.Tensor | None = None,
+        inputs_embeds: tf.Tensor | None = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        training: bool = False,
+    ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
+        """
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoProcessor, TFWav2Vec2Model
+        >>> from datasets import load_dataset
+        >>> import soundfile as sf
+
+        >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
+        >>> model = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
+
+
+        >>> def map_to_array(batch):
+        ...     speech, _ = sf.read(batch["file"])
+        ...     batch["speech"] = speech
+        ...     
return batch + + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> ds = ds.map(map_to_array) + + >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values # Batch size 1 + >>> hidden_states = model(input_values).last_hidden_state + ```""" + + output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states + output_attentions = output_attentions if output_attentions else self.config.output_attentions + return_dict = return_dict if return_dict else self.config.return_dict + + outputs = self.wav2vec2( + input_values=input_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "wav2vec2", None) is not None: + with tf.name_scope(self.wav2vec2.name): + self.wav2vec2.build(None) + + +@add_start_docstrings( + """TFWav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + WAV2VEC2_START_DOCSTRING, +) +class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel): + def __init__(self, config: Wav2Vec2Config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2") + self.dropout = keras.layers.Dropout(config.final_dropout) + self.lm_head = keras.layers.Dense(config.vocab_size, name="lm_head") + self.output_hidden_size = ( + config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size + ) + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameters will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2.feature_extractor.trainable = False + + @unpack_inputs + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC) + def call( + self, + input_values: tf.Tensor, + attention_mask: tf.Tensor | None = None, + token_type_ids: tf.Tensor | None = None, + position_ids: tf.Tensor | None = None, + head_mask: tf.Tensor | None = None, + inputs_embeds: tf.Tensor | None = None, + output_attentions: Optional[bool] = None, + labels: tf.Tensor | None = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
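Before the rest of the `labels` description below, a note on the `-100` convention it introduces: the loss code further down derives the CTC target lengths by counting non-negative label positions. A toy sketch (the label values are made up):

```python
import tensorflow as tf

labels = tf.constant([[5, 12, 7, -100, -100],
                      [3, 9, -100, -100, -100]])
# Mirrors `labels_mask = tf.cast(labels >= 0, tf.int32)` in the loss computation below.
target_lengths = tf.reduce_sum(tf.cast(labels >= 0, tf.int32), axis=-1)
print(target_lengths.numpy())  # [3 2]
```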
Indices should be in `[-100, 0, ...,
+            config.vocab_size - 1]` (see `input_values` docstring). Tokens with indices set to `-100` are ignored (masked),
+            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import tensorflow as tf
+        >>> from transformers import AutoProcessor, TFWav2Vec2ForCTC
+        >>> from datasets import load_dataset
+        >>> import soundfile as sf
+
+        >>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
+        >>> model = TFWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+
+
+        >>> def map_to_array(batch):
+        ...     speech, _ = sf.read(batch["file"])
+        ...     batch["speech"] = speech
+        ...     return batch
+
+
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> ds = ds.map(map_to_array)
+
+        >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values  # Batch size 1
+        >>> logits = model(input_values).logits
+        >>> predicted_ids = tf.argmax(logits, axis=-1)
+
+        >>> transcription = processor.decode(predicted_ids[0])
+
+        >>> # compute loss
+        >>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
+
+        >>> # Pass the target transcription as `text` to encode labels
+        >>> labels = processor(text=target_transcription, return_tensors="tf").input_ids
+
+        >>> loss = model(input_values, labels=labels).loss
+        ```"""
+        if labels is not None and tf.reduce_max(labels) >= self.config.vocab_size:
+            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")
+
+        outputs = self.wav2vec2(
+            input_values=input_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
+        )
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states, training=training)
+
+        logits = self.lm_head(hidden_states)
+
+        if labels is not None:
+            attention_mask = (
+                attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
+            )
+            input_lengths = self.wav2vec2._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))
+
+            # assuming that padded tokens are filled with -100
+            # when not being attended to
+            labels_mask = tf.cast(labels >= 0, tf.int32)
+            target_lengths = tf.reduce_sum(labels_mask, axis=-1)
+
+            loss = tf.nn.ctc_loss(
+                logits=logits,
+                labels=labels,
+                logit_length=input_lengths,
+                label_length=target_lengths,
+                blank_index=self.config.pad_token_id,
+                logits_time_major=False,
+            )
+
+            if self.config.ctc_loss_reduction == "sum":
+                loss = tf.reduce_sum(loss)
+            if self.config.ctc_loss_reduction == "mean":
+                loss = tf.reduce_mean(loss)
+
+            loss = tf.reshape(loss, (1,))
+        else:
+            loss = None
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TFCausalLMOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def build(self, input_shape=None):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "wav2vec2", None) is not None:
+            with tf.name_scope(self.wav2vec2.name):
+                self.wav2vec2.build(None)
+        if getattr(self, "lm_head", None) is not None:
+            with tf.name_scope(self.lm_head.name):
+                self.lm_head.build([None, None, self.output_hidden_size])
+
+
+class 
TFWav2Vec2ForSequenceClassification(TFWav2Vec2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.wav2vec2 = TFWav2Vec2MainLayer(config, name="wav2vec2") + self.num_layers = config.num_hidden_layers + 1 + with tf.name_scope(self._name_scope()): + if config.use_weighted_layer_sum: + self.layer_weights = self.add_weight( + shape=(self.num_layers,), initializer="ones", trainable=True, name="layer_weights" + ) + self.config = config + self.projector = keras.layers.Dense(units=config.classifier_proj_size, name="projector") + self.classifier = keras.layers.Dense(units=config.num_labels, activation=None, name="classifier") + + def freeze_feature_extractor(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameters will + not be updated during training. + """ + warnings.warn( + "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. " + "Please use the equivalent `freeze_feature_encoder` method instead.", + FutureWarning, + ) + self.freeze_feature_encoder() + + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2.feature_extractor.trainable = False + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for layer in self.wav2vec2.layers: + layer.trainable = False + + @unpack_inputs + def call( + self, + input_values: tf.Tensor, + attention_mask: tf.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + labels: tf.Tensor | None = None, + training: bool = False, + ) -> TFSequenceClassifierOutput | Tuple[tf.Tensor]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = tf.stack(hidden_states, axis=1) + norm_weights = tf.nn.softmax(self.layer_weights, axis=-1) + hidden_states = tf.reduce_sum(hidden_states * tf.reshape(norm_weights, [-1, 1, 1]), axis=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = tf.reduce_mean(hidden_states, axis=1) + else: + padding_mask = self._get_feature_vector_attention_mask(shape_list(hidden_states)[1], attention_mask) + padding_mask_float = tf.cast(padding_mask, hidden_states.dtype) + hidden_states = tf.multiply(hidden_states, tf.expand_dims(padding_mask_float, axis=-1)) + pooled_output = tf.divide( + tf.reduce_sum(hidden_states, axis=1), tf.expand_dims(tf.reduce_sum(padding_mask_float, axis=1), axis=1) + ) + logits = self.classifier(pooled_output) + loss = None + if labels is not None: + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True) + loss = loss_fn(tf.reshape(labels, [-1]), tf.reshape(logits, [-1, self.config.num_labels])) + if not return_dict: + output = (logits,) + 
outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def build(self, input_shape=None): + if self.built: + return + self.built = True + if getattr(self, "wav2vec2", None) is not None: + with tf.name_scope(self.wav2vec2.name): + self.wav2vec2.build(None) + if getattr(self, "projector", None) is not None: + with tf.name_scope(self.projector.name): + self.projector.build([None, None, self.config.hidden_size]) + if getattr(self, "classifier", None) is not None: + with tf.name_scope(self.classifier.name): + self.classifier.build([None, None, self.config.classifier_proj_size]) + + +__all__ = ["TFWav2Vec2ForCTC", "TFWav2Vec2Model", "TFWav2Vec2PreTrainedModel", "TFWav2Vec2ForSequenceClassification"] diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2/modeling_wav2vec2.py b/docs/transformers/build/lib/transformers/models/wav2vec2/modeling_wav2vec2.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac0e21486e7c49b22c7d3e78a87c046153610f8 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -0,0 +1,2728 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Wav2Vec2 model.""" + +import math +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + MaskedLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + cached_file, + is_peft_available, + is_safetensors_available, + logging, + replace_return_docstrings, +) +from .configuration_wav2vec2 import Wav2Vec2Config + + +WAV2VEC2_ADAPTER_PT_FILE = "adapter.{}.bin" +WAV2VEC2_ADAPTER_SAFE_FILE = "adapter.{}.safetensors" + +if is_safetensors_available(): + from safetensors.torch import load_file as safe_load_file + + +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 2 + +# General docstring +_CONFIG_FOR_DOC = "Wav2Vec2Config" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 768] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 53.48 + +# Audio class docstring +_SEQ_CLASS_CHECKPOINT = "superb/wav2vec2-base-superb-ks" +_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'" +_SEQ_CLASS_EXPECTED_LOSS = 6.54 + +# Frame class docstring +_FRAME_CLASS_CHECKPOINT = "anton-l/wav2vec2-base-superb-sd" +_FRAME_EXPECTED_OUTPUT = [0, 0] + +# Speaker Verification docstring +_XVECTOR_CHECKPOINT = "anton-l/wav2vec2-base-superb-sv" +_XVECTOR_EXPECTED_OUTPUT = 0.98 + + +@dataclass +class Wav2Vec2ForPreTrainingOutput(ModelOutput): + """ + Output type of [`Wav2Vec2ForPreTraining`], with potential hidden states and attentions. + + Args: + loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official + paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss. + projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked + projected quantized states. + projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive + target vectors for contrastive loss. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . + diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . + """ + + loss: Optional[torch.FloatTensor] = None + projected_states: Optional[torch.FloatTensor] = None + projected_quantized_states: Optional[torch.FloatTensor] = None + codevector_perplexity: Optional[torch.FloatTensor] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + contrastive_loss: Optional[torch.FloatTensor] = None + diversity_loss: Optional[torch.FloatTensor] = None + + +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. 
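Plugging numbers into the docstring above: with the library defaults `mask_prob=0.05` and `mask_length=10`, and a 292-frame sequence (the expected output shape quoted earlier in this file), each sequence receives one or two 10-frame spans depending on the random rounding epsilon:

```python
mask_prob, mask_length, sequence_length = 0.05, 10, 292
# `epsilon` is drawn uniformly from [0, 1) for probabilistic rounding.
for epsilon in (0.0, 0.99):
    num_masked_span = int(mask_prob * sequence_length / mask_length + epsilon)
    print(num_masked_span)  # 1, then 2
```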
+ """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. 
+        if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+            # `sequence_length` in which case the last token has to be a padding
+            # token which we can use as a dummy mask id
+            dummy_mask_idx = sequence_length - 1
+        else:
+            dummy_mask_idx = spec_aug_mask_idx[0]
+
+        spec_aug_mask_idx = np.concatenate(
+            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+        )
+        spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+    # expand masked indices to masked spans
+    spec_aug_mask_idxs = np.broadcast_to(
+        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+    # add offset to the starting indexes so that indexes now create a span
+    offsets = np.arange(mask_length)[None, None, :]
+    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+        batch_size, max_num_masked_span * mask_length
+    )
+    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+    # ensure that we cannot have indices larger than sequence_length
+    if spec_aug_mask_idxs.max() > sequence_length - 1:
+        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+    # scatter indices to mask
+    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+    return spec_aug_mask
+
+
+def _sample_negative_indices(
+    features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
+):
+    """
+    Sample `num_negatives` vectors from feature vectors.
+    """
+    batch_size, sequence_length = features_shape
+
+    # generate indices of the positive vectors themselves, repeat them `num_negatives` times
+    sequence_length_range = np.arange(sequence_length)
+
+    # get `num_negatives` random vector indices from the same utterance
+    sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
+
+    mask_time_indices = (
+        mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
+    )
+
+    for batch_idx in range(batch_size):
+        high = mask_time_indices[batch_idx].sum() - 1
+        mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
+
+        feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
+        sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
+        # avoid sampling the same positive vector, but keep the distribution uniform
+        sampled_indices[sampled_indices >= feature_indices] += 1
+
+        # remap to actual indices
+        sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
+
+        # correct for batch size
+        sampled_negative_indices[batch_idx] += batch_idx * sequence_length
+
+    return sampled_negative_indices
+
+
+class Wav2Vec2NoLayerNormConvLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
+        self.out_conv_dim = config.conv_dim[layer_id]
+
+        self.conv = nn.Conv1d(
+            self.in_conv_dim,
+            self.out_conv_dim,
+            kernel_size=config.conv_kernel[layer_id],
+            stride=config.conv_stride[layer_id],
+            bias=config.conv_bias,
+        )
+        self.activation = ACT2FN[config.feat_extract_activation]
+
+    def forward(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = self.activation(hidden_states)
+        return 
hidden_states + + +class Wav2Vec2LayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Wav2Vec2GroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +class Wav2Vec2PositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Wav2Vec2SamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +class Wav2Vec2FeatureEncoder(nn.Module): + 
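# Note (assuming the typical base configuration, conv_stride=(5, 2, 2, 2, 2, 2, 2)): the
+ # convolution stack below downsamples raw 16 kHz audio by a factor of prod(conv_stride) = 320,
+ # i.e. it emits one feature vector per ~20 ms of audio; other checkpoints may use different strides.
+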
"""Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [ + Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [ + Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) + ] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + if self._requires_grad and self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + conv_layer.__call__, + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +class Wav2Vec2FeatureExtractor(Wav2Vec2FeatureEncoder): + def __init__(self, config): + super().__init__(config) + warnings.warn( + f"The class `{self.__class__.__name__}` has been depreciated " + "and will be removed in Transformers v5. " + f"Use `{self.__class__.__bases__[0].__name__}` instead.", + FutureWarning, + ) + + +class Wav2Vec2FeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2 +class Wav2Vec2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + bias: bool = True, + is_causal: bool = False, + config: Optional[Wav2Vec2Config] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + self.is_causal = is_causal + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.reshape(*proj_shape)
+ value_states = value_states.reshape(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but are"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Wav2Vec2
+class Wav2Vec2FlashAttention2(Wav2Vec2Attention):
+ """
+ Wav2Vec2 flash attention module. This module inherits from `Wav2Vec2Attention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
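+
+ A minimal loading sketch (assuming a compatible `flash-attn` installation and half-precision
+ weights; `attn_implementation` is the standard `from_pretrained` switch):
+
+ ```python
+ >>> import torch
+ >>> from transformers import Wav2Vec2Model
+
+ >>> model = Wav2Vec2Model.from_pretrained(
+ ... "facebook/wav2vec2-base-960h", attn_implementation="flash_attention_2", torch_dtype=torch.float16
+ ... ) # doctest: +SKIP
+ ```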
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Wav2Vec2FlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError("Wav2Vec2FlashAttention2 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, usually we cast the layer norms to float32 for training stability reasons,
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seem to be silently cast to float32; this might be related to"
+ f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=self.dropout if self.training else 0.0,
+ is_causal=self.is_causal,
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Wav2Vec2SdpaAttention(Wav2Vec2Attention):
+ # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Wav2Vec2
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+ if output_attentions or layer_head_mask is not None:
+ # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Wav2Vec2Model is using Wav2Vec2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
+ ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states,
+ key_value_states=key_value_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ query_states = self._shape(query_states, tgt_len, bsz)
+
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+ # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
+ is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+
+ # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
+ # but we are fine here as `_shape` does call `.contiguous()`.
Reference: https://github.com/pytorch/pytorch/issues/112577 + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + +WAV2VEC2_ATTENTION_CLASSES = { + "eager": Wav2Vec2Attention, + "sdpa": Wav2Vec2SdpaAttention, + "flash_attention_2": Wav2Vec2FlashAttention2, +} + + +class Wav2Vec2FeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class Wav2Vec2EncoderLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = WAV2VEC2_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.feed_forward = Wav2Vec2FeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states, attention_mask=None, output_attentions=False): + attn_residual = hidden_states + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states + self.feed_forward(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = WAV2VEC2_ATTENTION_CLASSES[config._attn_implementation]( + embed_dim=config.hidden_size, + num_heads=config.num_attention_heads, + dropout=config.attention_dropout, + is_decoder=False, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layer_norm = nn.LayerNorm(config.hidden_size, 
eps=config.layer_norm_eps) + self.feed_forward = Wav2Vec2FeedForward(config) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if getattr(config, "adapter_attn_dim", None) is not None: + self.adapter_layer = Wav2Vec2AttnAdapterLayer(config) + else: + self.adapter_layer = None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + attn_residual = hidden_states + hidden_states = self.layer_norm(hidden_states) + hidden_states, attn_weights, _ = self.attention( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = self.dropout(hidden_states) + hidden_states = attn_residual + hidden_states + hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states)) + + if self.adapter_layer is not None: + hidden_states = hidden_states + self.adapter_layer(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class Wav2Vec2Encoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def forward( + self, + hidden_states: torch.tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens output 0 + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states[~expand_attention_mask] = 0 + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + 
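# Note: when `synced_gpus` is True (DeepSpeed ZeRO-3 / FSDP), a layer selected for LayerDrop is
+ # still executed so that all ranks launch the same collective ops; only its recorded attention
+ # weights are discarded below via `layer_outputs = (None, None)`.
+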
else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class Wav2Vec2EncoderStableLayerNorm(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList( + [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens are not attended to + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + position_embeddings = self.pos_conv_embed(hidden_states) + hidden_states = hidden_states + position_embeddings + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = 
self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class Wav2Vec2GumbelVectorQuantizer(nn.Module):
+ """
+ Vector quantization using Gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
+ GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.num_groups = config.num_codevector_groups
+ self.num_vars = config.num_codevectors_per_group
+
+ if config.codevector_dim % self.num_groups != 0:
+ raise ValueError(
+ f"`config.codevector_dim` {config.codevector_dim} must be divisible "
+ f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
+ )
+
+ # storage for codebook variables (codewords)
+ self.codevectors = nn.Parameter(
+ torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
+ )
+ self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
+
+ # can be decayed for training
+ self.temperature = 2
+
+ @staticmethod
+ def _compute_perplexity(probs, mask=None):
+ if mask is not None:
+ mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
+ probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
+ marginal_probs = probs.sum(dim=0) / mask.sum()
+ else:
+ marginal_probs = probs.mean(dim=0)
+
+ perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
+ return perplexity
+
+ def forward(self, hidden_states, mask_time_indices=None):
+ batch_size, sequence_length, hidden_size = hidden_states.shape
+
+ # project to codevector dim
+ hidden_states = self.weight_proj(hidden_states)
+ hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
+
+ if self.training:
+ # sample code vector probs via Gumbel softmax in a differentiable way
+ codevector_probs = nn.functional.gumbel_softmax(
+ hidden_states.float(), tau=self.temperature, hard=True
+ ).type_as(hidden_states)
+
+ # compute perplexity
+ codevector_soft_dist = torch.softmax(
+ hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
+ )
+ perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
+ else:
+ # take argmax in non-differentiable way
+ # compute hard codevector distribution (one hot)
+ codevector_idx = hidden_states.argmax(dim=-1)
+ codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
+ -1, codevector_idx.view(-1, 1), 1.0
+ )
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
+
+ perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
+
+ codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
+ # use probs to retrieve codevectors
+ codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
+ codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
+ codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
+
+ return codevectors, perplexity
+
+
+class Wav2Vec2Adapter(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ # feature dim might need to be down-projected
+ if config.output_hidden_size != config.hidden_size:
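+ # (e.g., when this encoder feeds a decoder of a different width, as in a
+ # speech-encoder-decoder setup, `output_hidden_size` may be set to the decoder width)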
+ self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) + self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size) + else: + self.proj = self.proj_layer_norm = None + + self.layers = nn.ModuleList(Wav2Vec2AdapterLayer(config) for _ in range(config.num_adapter_layers)) + self.layerdrop = config.layerdrop + + def forward(self, hidden_states): + # down project hidden_states if necessary + if self.proj is not None and self.proj_layer_norm is not None: + hidden_states = self.proj(hidden_states) + hidden_states = self.proj_layer_norm(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + + for layer in self.layers: + layerdrop_prob = np.random.random() + if not self.training or (layerdrop_prob > self.layerdrop): + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Wav2Vec2AdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.output_hidden_size, + 2 * config.output_hidden_size, + config.adapter_kernel_size, + stride=config.adapter_stride, + padding=1, + ) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=1) + + return hidden_states + + +class Wav2Vec2AttnAdapterLayer(nn.Module): + def __init__(self, config): + """ + Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed + up training throughput. + """ + super().__init__() + self.input_dim = config.adapter_attn_dim + self.hidden_dim = config.hidden_size + + self.norm = nn.LayerNorm(self.hidden_dim) + self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim) + self.act_fn = nn.ReLU() + self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim) + + def forward(self, hidden_states: torch.FloatTensor): + hidden_states = self.norm(hidden_states) + + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +class Wav2Vec2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Wav2Vec2Config + base_model_prefix = "wav2vec2" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init. 
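+ # Note: `reset_parameters()` restores the default `nn.Linear` initialization, and the
+ # `_is_hf_initialized` flag marks these modules as already initialized so the generic
+ # `nn.Linear` branch below does not overwrite them on a later pass.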
+ if isinstance(module, Wav2Vec2ForPreTraining):
+ module.project_hid.reset_parameters()
+ module.project_q.reset_parameters()
+ module.project_hid._is_hf_initialized = True
+ module.project_q._is_hf_initialized = True
+ # gumbel softmax requires special init
+ elif isinstance(module, Wav2Vec2GumbelVectorQuantizer):
+ module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+ module.weight_proj.bias.data.zero_()
+ nn.init.uniform_(module.codevectors)
+ elif isinstance(module, Wav2Vec2PositionalConvEmbedding):
+ nn.init.normal_(
+ module.conv.weight,
+ mean=0,
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+ )
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, Wav2Vec2FeatureProjection):
+ k = math.sqrt(1 / module.projection.in_features)
+ nn.init.uniform_(module.projection.weight, a=-k, b=k)
+ nn.init.uniform_(module.projection.bias, a=-k, b=k)
+ elif isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ nn.init.kaiming_normal_(module.weight)
+
+ if module.bias is not None:
+ k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+ nn.init.uniform_(module.bias, a=-k, b=k)
+
+ def _get_feat_extract_output_lengths(
+ self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+ ):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ if add_adapter:
+ for _ in range(self.config.num_adapter_layers):
+ input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(
+ self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+ ):
+ # Effectively attention_mask.sum(-1), but not inplace to be able to run
+ # in inference mode.
+ non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+ output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+ output_lengths = output_lengths.to(torch.long)
+
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations make sure that all values before the output length indices are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+ def _get_adapters(self):
+ if self.config.adapter_attn_dim is None:
+ raise ValueError(f"{self.__class__} has no adapter layers.
Make sure to define `config.adapter_attn_dim`.")
+
+ adapter_weights = {}
+ for name, module in self.named_modules():
+ if isinstance(module, Wav2Vec2AttnAdapterLayer):
+ for param_name, param in module.named_parameters():
+ adapter_weights[".".join([name, param_name])] = param
+
+ if isinstance(self, Wav2Vec2ForCTC):
+ for name, param in self.lm_head.named_parameters():
+ adapter_weights[".".join(["lm_head", name])] = param
+
+ return adapter_weights
+
+ def init_adapter_layers(self):
+ """
+ (Re-)initialize attention adapter layers and lm head for adapter-only fine-tuning
+ """
+ # init attention adapters
+ for module in self.modules():
+ if isinstance(module, Wav2Vec2AttnAdapterLayer):
+ self._init_weights(module)
+
+ # init lm head
+ if isinstance(self, Wav2Vec2ForCTC):
+ self._init_weights(self.lm_head)
+
+ def load_adapter(self, target_lang: str, force_load=True, **kwargs):
+ r"""
+ Load a language adapter model from a pre-trained adapter model.
+
+ Parameters:
+ target_lang (`str`):
+ Has to be a language id of an existing adapter weight. Adapter weights are stored in the format
+ adapter.<lang>.safetensors or adapter.<lang>.bin
+ force_load (`bool`, defaults to `True`):
+ Whether the weights shall be loaded even if `target_lang` matches `self.target_lang`.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download:
+ Deprecated and ignored. All downloads are now resumed by default when possible.
+ Will be removed in v5 of Transformers.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+
+ <Tip>
+
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+ </Tip>
+
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+ <Tip>
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+ </Tip>
+
+ Examples:
+
+ ```python
+ >>> from transformers import Wav2Vec2ForCTC, AutoProcessor
+
+ >>> ckpt = "facebook/mms-1b-all"
+ >>> processor = AutoProcessor.from_pretrained(ckpt)
+ >>> model = Wav2Vec2ForCTC.from_pretrained(ckpt, target_lang="eng")
+ >>> # set specific language
+ >>> processor.tokenizer.set_target_lang("spa")
+ >>> model.load_adapter("spa")
+ ```
+ """
+ if self.config.adapter_attn_dim is None:
+ raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.")
+
+ if target_lang == self.target_lang and not force_load:
+ logger.warning(f"Adapter weights are already set to {target_lang}.")
+ return
+
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", None)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ token = kwargs.pop("token", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+
+ if use_auth_token is not None:
+ warnings.warn(
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+ FutureWarning,
+ )
+ if token is not None:
+ raise ValueError(
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+ )
+ token = use_auth_token
+
+ model_path_or_id = self.config._name_or_path
+ state_dict = None
+
+ # 1. Let's first try loading a safetensors adapter weight
+ if use_safetensors is not False:
+ filepath = WAV2VEC2_ADAPTER_SAFE_FILE.format(target_lang)
+
+ try:
+ weight_path = cached_file(
+ model_path_or_id,
+ filename=filepath,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ cache_dir=cache_dir,
+ )
+
+ state_dict = safe_load_file(weight_path)
+
+ except EnvironmentError:
+ if use_safetensors:
+ # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
+ # to the original exception.
+ raise
+
+ except Exception:
+ # For any other exception, we throw a generic error.
+ if use_safetensors:
+ raise EnvironmentError(
+ f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
+ f" directory containing a file named {filepath}."
+ )
+
+ # 2. If this didn't work let's try loading a PyTorch adapter weight
+ if state_dict is None:
+ filepath = WAV2VEC2_ADAPTER_PT_FILE.format(target_lang)
+
+ try:
+ weight_path = cached_file(
+ model_path_or_id,
+ filename=filepath,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ cache_dir=cache_dir,
+ )
+
+ state_dict = torch.load(
+ weight_path,
+ map_location="cpu",
+ weights_only=True,
+ )
+
+ except EnvironmentError:
+ # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
+ # to the original exception.
+ raise
+
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise EnvironmentError(
+ f"Can't load the model for '{model_path_or_id}'.
If you were trying to load it"
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
+ f" directory containing a file named {filepath}."
+ )
+
+ adapter_weights = self._get_adapters()
+ unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys())
+ missing_keys = set(adapter_weights.keys()) - set(state_dict.keys())
+
+ if len(unexpected_keys) > 0:
+ raise ValueError(f"The adapter weights {weight_path} have unexpected keys: {', '.join(unexpected_keys)}.")
+ elif len(missing_keys) > 0:
+ raise ValueError(f"The adapter weights {weight_path} have missing keys: {', '.join(missing_keys)}.")
+
+ # make sure the vocab size is now correct
+ target_vocab_size = state_dict["lm_head.weight"].shape[0]
+ if target_vocab_size != self.config.vocab_size:
+ self.lm_head = nn.Linear(
+ self.config.output_hidden_size, target_vocab_size, device=self.device, dtype=self.dtype
+ )
+ self.config.vocab_size = target_vocab_size
+
+ # cast the loaded adapter weights to exactly the same dtype and device as the current ones before overwriting them
+ state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()}
+ self.load_state_dict(state_dict, strict=False)
+
+ # set target language correctly
+ self.target_lang = target_lang
+
+
+WAV2VEC2_START_DOCSTRING = r"""
+ Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
+ Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
+ Auli.
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, etc.).
+
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
+ behavior.
+
+ Parameters:
+ config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+WAV2VEC2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+ Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
+ into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
+ soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
+ conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
+ 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ <Tip>
+
+ `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
+ True`.
For all models whose processor has `config.return_attention_mask == False`, such as
+ [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
+ passed to avoid degraded performance when doing batched inference. For such models `input_values` should
+ simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
+ different results depending on whether `input_values` is padded or not.
+
+ </Tip>
+
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top.",
+ WAV2VEC2_START_DOCSTRING,
+)
+class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
+ def __init__(self, config: Wav2Vec2Config):
+ super().__init__(config)
+ self.config = config
+ self.feature_extractor = Wav2Vec2FeatureEncoder(config)
+ self.feature_projection = Wav2Vec2FeatureProjection(config)
+
+ # model only needs masking vector if mask prob is > 0.0
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ if config.do_stable_layer_norm:
+ self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
+ else:
+ self.encoder = Wav2Vec2Encoder(config)
+
+ self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def freeze_feature_extractor(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ warnings.warn(
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
+ FutureWarning,
+ )
+ self.freeze_feature_encoder()
+
+ def freeze_feature_encoder(self):
+ """
+ Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+ not be updated during training.
+ """
+ self.feature_extractor._freeze_parameters()
+
+ def _mask_hidden_states(
+ self,
+ hidden_states: torch.FloatTensor,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
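+ Note: masking happens in place on `hidden_states`: time-axis spans are replaced with the
+ learned `masked_spec_embed` vector, while feature-axis spans are zeroed out.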
+ """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Wav2Vec2BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + 
extract_features=extract_features,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+
+@add_start_docstrings("""Wav2Vec2 Model with a quantizer and `VQ` head on top.""", WAV2VEC2_START_DOCSTRING)
+class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
+    def __init__(self, config: Wav2Vec2Config):
+        super().__init__(config)
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
+
+        self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
+
+        self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
+        self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def set_gumbel_temperature(self, temperature: int):
+        """
+        Set the Gumbel softmax temperature to a given value. Only necessary for training.
+        """
+        self.quantizer.temperature = temperature
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    @staticmethod
+    def compute_contrastive_logits(
+        target_features: torch.FloatTensor,
+        negative_features: torch.FloatTensor,
+        predicted_features: torch.FloatTensor,
+        temperature: float = 0.1,
+    ):
+        """
+        Compute logits for contrastive loss, using cosine similarity as the distance measure between
+        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, a temperature can be applied.
+        """
+        target_features = torch.cat([target_features, negative_features], dim=0)
+
+        logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
+            target_features
+        )
+
+        # apply temperature
+        logits = logits / temperature
+        return logits
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=Wav2Vec2ForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        mask_time_indices: Optional[torch.BoolTensor] = None,
+        sampled_negative_indices: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, Wav2Vec2ForPreTrainingOutput]:
+        r"""
+        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
+            masked extracted features in *config.proj_codevector_dim* space.
+        sampled_negative_indices (`torch.LongTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
+            Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
+            Required input for pre-training.
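+            Can be obtained with [`~models.wav2vec2.modeling_wav2vec2._sample_negative_indices`], as shown in the
+            example below.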
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
+        >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
+        >>> from datasets import load_dataset
+
+        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
+        >>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
+
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values  # Batch size 1
+
+        >>> # compute masked indices
+        >>> batch_size, raw_sequence_length = input_values.shape
+        >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
+        >>> mask_time_indices = _compute_mask_indices(
+        ...     shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
+        ... )
+        >>> sampled_negative_indices = _sample_negative_indices(
+        ...     features_shape=(batch_size, sequence_length),
+        ...     num_negatives=model.config.num_negatives,
+        ...     mask_time_indices=mask_time_indices,
+        ... )
+        >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
+        >>> sampled_negative_indices = torch.tensor(
+        ...     data=sampled_negative_indices, device=input_values.device, dtype=torch.long
+        ... )
+
+        >>> with torch.no_grad():
+        ...     outputs = model(input_values, mask_time_indices=mask_time_indices)
+
+        >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
+        >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
+
+        >>> # show that cosine similarity is much higher than random
+        >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
+        tensor(True)
+
+        >>> # for contrastive loss training model should be put into train mode
+        >>> model = model.train()
+        >>> loss = model(
+        ...     input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
+        ... ).loss
+        ```"""
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if mask_time_indices is not None:
+            mask_time_indices = mask_time_indices.to(torch.bool)
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            mask_time_indices=mask_time_indices,
+            return_dict=return_dict,
+        )
+
+        # 1. project all transformed features (including masked) to final vq dim
+        transformer_features = self.project_hid(outputs[0])
+
+        # 2. quantize all (unmasked) extracted features and project to final vq dim
+        extract_features = self.dropout_features(outputs[1])
+
+        if attention_mask is not None:
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
+
+        quantized_features, codevector_perplexity = self.quantizer(
+            extract_features, mask_time_indices=mask_time_indices
+        )
+
+        quantized_features = quantized_features.to(self.project_q.weight.dtype)
+        quantized_features = self.project_q(quantized_features)
+
+        loss = contrastive_loss = diversity_loss = None
+        if sampled_negative_indices is not None:
+            batch_size, sequence_length, hidden_size = quantized_features.shape
+
+            # for training, we sample negatives
+            # 3. sample K negatives (distractors) quantized states for contrastive loss
+            # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
+            # sample negative quantized vectors BTC => (BxT)C
+            negative_quantized_features = quantized_features.view(-1, hidden_size)[
+                sampled_negative_indices.long().view(-1)
+            ]
+            negative_quantized_features = negative_quantized_features.view(
+                batch_size, sequence_length, -1, hidden_size
+            ).permute(2, 0, 1, 3)
+
+            # 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \tilde{q}_t]) / \kappa`
+            # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
+            logits = self.compute_contrastive_logits(
+                quantized_features[None, :],
+                negative_quantized_features,
+                transformer_features,
+                self.config.contrastive_logits_temperature,
+            )
+
+            # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
+            # its cosine similarity will be masked
+            neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
+
+            if neg_is_pos.any():
+                logits[1:][neg_is_pos] = float("-inf")
+
+            # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logits) =
+            # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\tilde{q}} exp(sim(c_t, \tilde{q})/\kappa))
+            logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
+            target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
+
+            contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
+            # 7. compute diversity loss: \mathbf{L}_d
+            num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
+            diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
+
+            # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
+            loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
+
+        if not return_dict:
+            if loss is not None:
+                return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+            return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
+
+        return Wav2Vec2ForPreTrainingOutput(
+            loss=loss,
+            projected_states=transformer_features,
+            projected_quantized_states=quantized_features,
+            codevector_perplexity=codevector_perplexity,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            contrastive_loss=contrastive_loss,
+            diversity_loss=diversity_loss,
+        )
+
+
+@add_start_docstrings("""Wav2Vec2 Model with a `language modeling` head on top.""", WAV2VEC2_START_DOCSTRING)
+class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        warnings.warn(
+            "The class `Wav2Vec2ForMaskedLM` is deprecated. Please use `Wav2Vec2ForCTC` instead.", FutureWarning
+        )
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.dropout = nn.Dropout(config.final_dropout)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_values: torch.FloatTensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, MaskedLMOutput]:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.wav2vec2(
+            input_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+        logits = self.lm_head(hidden_states)
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return output
+
+        return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+
+
+@add_start_docstrings(
+    """Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
+    WAV2VEC2_START_DOCSTRING,
+    """
+        target_lang (`str`, *optional*):
+            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
+            adapter.<lang>.bin. Only relevant when using an instance of [`Wav2Vec2ForCTC`] with adapters. Uses 'eng'
+            by default.
+    """,
+)
+class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
+    def __init__(self, config, target_lang: Optional[str] = None):
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.dropout = nn.Dropout(config.final_dropout)
+
+        self.target_lang = target_lang
+
+        if config.vocab_size is None:
+            raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`, "
+                "or define `vocab_size` in your model's configuration."
+            )
+        output_hidden_size = (
+            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
+        )
+        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def tie_weights(self):
+        """
+        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
+        passing `target_lang=...` to `from_pretrained(...)`.
+
+        This method is **not** supposed to be called by the user and is prone to be changed in the future.
+        """
+
+        # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed
+        # to correctly load adapter layers for Wav2Vec2 so that we do not have to introduce a new API to
+        # [`PreTrainedModel`]. While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so it is
+        # ok to repurpose this function here.
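+        # Illustrative, hypothetical usage that exercises this path:
+        #     model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all", target_lang="fra")
+        # `from_pretrained` calls `tie_weights`, which then loads the 'fra' adapter weights below.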
+        target_lang = self.target_lang
+
+        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
+            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
+        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
+            logger.info("By default `target_lang` is set to 'eng'.")
+        elif target_lang is not None:
+            self.load_adapter(target_lang, force_load=True)
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.wav2vec2.parameters():
+            param.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=CausalLMOutput,
+        config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
+    )
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, CausalLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
+            Labels for connectionist temporal classification. Note that `target_length` has to be smaller than or
+            equal to the sequence length of the output logits. Indices are selected in `[-100, 0, ...,
+            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
+            labels in `[0, ..., config.vocab_size - 1]`.
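+
+            A minimal sketch of preparing CTC labels with a processor (assuming a `Wav2Vec2Processor` named
+            `processor` and a raw waveform `speech`; the names are illustrative):
+
+            ```python
+            >>> inputs = processor(audio=speech, text="HELLO WORLD", return_tensors="pt")  # doctest: +SKIP
+            >>> loss = model(inputs["input_values"], labels=inputs["labels"]).loss  # doctest: +SKIP
+            ```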
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if labels is not None and labels.max() >= self.config.vocab_size:
+            raise ValueError(f"Label values must be smaller than vocab_size: {self.config.vocab_size}")
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # retrieve loss input_lengths from attention_mask
+            attention_mask = (
+                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
+            )
+            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+
+            # assuming that padded tokens are filled with -100
+            # when not being attended to
+            labels_mask = labels >= 0
+            target_lengths = labels_mask.sum(-1)
+            flattened_targets = labels.masked_select(labels_mask)
+
+            # ctc_loss doesn't support fp16
+            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
+
+            with torch.backends.cudnn.flags(enabled=False):
+                loss = nn.functional.ctc_loss(
+                    log_probs,
+                    flattened_targets,
+                    input_lengths,
+                    target_lengths,
+                    blank=self.config.pad_token_id,
+                    reduction=self.config.ctc_loss_reduction,
+                    zero_infinity=self.config.ctc_zero_infinity,
+                )
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutput(
+            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
+        )
+
+
+@add_start_docstrings(
+    """
+    Wav2Vec2 Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+    SUPERB Keyword Spotting.
+    """,
+    WAV2VEC2_START_DOCSTRING,
+)
+class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        if hasattr(config, "add_adapter") and config.add_adapter:
+            raise ValueError(
+                "Sequence classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
+            )
+        self.wav2vec2 = Wav2Vec2Model(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.wav2vec2.parameters():
+            param.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_SEQ_CLASS_CHECKPOINT,
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
+    )
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = outputs[0]
+
+        hidden_states = self.projector(hidden_states)
+        if attention_mask is None:
+            pooled_output = hidden_states.mean(dim=1)
+        else:
+            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_padding_mask] = 0.0
+            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    Wav2Vec2 Model with a frame classification head on top for tasks like Speaker Diarization.
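+
+    The head classifies every encoder output frame independently; no pooling over the time axis is applied.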
+    """,
+    WAV2VEC2_START_DOCSTRING,
+)
+class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        if hasattr(config, "add_adapter") and config.add_adapter:
+            raise ValueError(
+                "Audio frame classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
+            )
+        self.wav2vec2 = Wav2Vec2Model(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.num_labels = config.num_labels
+
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.wav2vec2.parameters():
+            param.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_FRAME_CLASS_CHECKPOINT,
+        output_type=TokenClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+        expected_output=_FRAME_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, TokenClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*):
+            Frame-level labels for computing the classification loss, given as one-hot vectors over
+            `config.num_labels` classes (the loss below takes the argmax over the label dimension).
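+
+            A hypothetical sketch of such one-hot frame labels (`num_frames` must match the logits' time dimension):
+
+            ```python
+            >>> import torch
+            >>> frame_ids = torch.randint(0, model.config.num_labels, (1, num_frames))  # doctest: +SKIP
+            >>> labels = torch.nn.functional.one_hot(frame_ids, num_classes=model.config.num_labels)  # doctest: +SKIP
+            ```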
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.wav2vec2(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = outputs[0]
+
+        logits = self.classifier(hidden_states)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), dim=1))
+
+        if not return_dict:
+            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+            return output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
+class AMSoftmaxLoss(nn.Module):
+    def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
+        super().__init__()
+        self.scale = scale
+        self.margin = margin
+        self.num_labels = num_labels
+        self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
+        self.loss = nn.CrossEntropyLoss()
+
+    def forward(self, hidden_states, labels):
+        labels = labels.flatten()
+        weight = nn.functional.normalize(self.weight, dim=0)
+        hidden_states = nn.functional.normalize(hidden_states, dim=1)
+        cos_theta = torch.mm(hidden_states, weight)
+        # additive margin: subtract `margin` from the cosine of the true class only
+        psi = cos_theta - self.margin
+
+        onehot = nn.functional.one_hot(labels, self.num_labels)
+        logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
+        loss = self.loss(logits, labels)
+
+        return loss
+
+
+class TDNNLayer(nn.Module):
+    def __init__(self, config, layer_id=0):
+        super().__init__()
+        self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
+        self.out_conv_dim = config.tdnn_dim[layer_id]
+        self.kernel_size = config.tdnn_kernel[layer_id]
+        self.dilation = config.tdnn_dilation[layer_id]
+
+        self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
+        self.activation = nn.ReLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if is_peft_available():
+            from peft.tuners.lora import LoraLayer
+
+            if isinstance(self.kernel, LoraLayer):
+                warnings.warn(
+                    "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
+                    "You should exclude TDNNLayer from LoRA's target modules.",
+                )
+
+        # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
+        hidden_states = hidden_states.transpose(1, 2)
+        weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
+        hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
+        hidden_states = hidden_states.transpose(1, 2)
+
+        hidden_states = self.activation(hidden_states)
+        return hidden_states
+
+
+@add_start_docstrings(
+    """
+    Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification.
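+
+    The head applies TDNN layers followed by statistics pooling (mean and standard deviation over time) and two
+    linear layers to produce utterance-level x-vector embeddings.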
+    """,
+    WAV2VEC2_START_DOCSTRING,
+)
+class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
+
+        tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
+        self.tdnn = nn.ModuleList(tdnn_layers)
+
+        self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
+        self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
+
+        self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
+
+        self.init_weights()
+
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        warnings.warn(
+            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+            "Please use the equivalent `freeze_feature_encoder` method instead.",
+            FutureWarning,
+        )
+        self.freeze_feature_encoder()
+
+    def freeze_feature_encoder(self):
+        """
+        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+        not be updated during training.
+        """
+        self.wav2vec2.feature_extractor._freeze_parameters()
+
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.wav2vec2.parameters():
+            param.requires_grad = False
+
+    def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+        """
+        Computes the output length of the TDNN layers
+        """
+
+        def _conv_out_length(input_length, kernel_size, stride):
+            # 1D convolutional layer output length formula taken
+            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+            return (input_length - kernel_size) // stride + 1
+
+        for kernel_size in self.config.tdnn_kernel:
+            input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
+
+        return input_lengths
+
+    @add_start_docstrings_to_model_forward(WAV2VEC2_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        checkpoint=_XVECTOR_CHECKPOINT,
+        output_type=XVectorOutput,
+        config_class=_CONFIG_FOR_DOC,
+        modality="audio",
+        expected_output=_XVECTOR_EXPECTED_OUTPUT,
+    )
+    def forward(
+        self,
+        input_values: Optional[torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        labels: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, XVectorOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the speaker classification loss (an additive-margin softmax over
+            `config.num_labels` classes). Indices should be in `[0, ..., config.num_labels - 1]`.
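+
+            A minimal sketch (hypothetical `input_values` batch of two utterances, speaker ids 0 and 1):
+
+            ```python
+            >>> labels = torch.tensor([0, 1])  # doctest: +SKIP
+            >>> loss = model(input_values, labels=labels).loss  # doctest: +SKIP
+            ```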
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + + for tdnn_layer in self.tdnn: + hidden_states = tdnn_layer(hidden_states) + + # Statistic Pooling + if attention_mask is None: + mean_features = hidden_states.mean(dim=1) + std_features = hidden_states.std(dim=1) + else: + feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1)) + tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths) + mean_features = [] + std_features = [] + for i, length in enumerate(tdnn_output_lengths): + mean_features.append(hidden_states[i, :length].mean(dim=0)) + std_features.append(hidden_states[i, :length].std(dim=0)) + mean_features = torch.stack(mean_features) + std_features = torch.stack(std_features) + statistic_pooling = torch.cat([mean_features, std_features], dim=-1) + + output_embeddings = self.feature_extractor(statistic_pooling) + logits = self.classifier(output_embeddings) + + loss = None + if labels is not None: + loss = self.objective(logits, labels) + + if not return_dict: + output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return XVectorOutput( + loss=loss, + logits=logits, + embeddings=output_embeddings, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "Wav2Vec2ForAudioFrameClassification", + "Wav2Vec2ForCTC", + "Wav2Vec2ForMaskedLM", + "Wav2Vec2ForPreTraining", + "Wav2Vec2ForSequenceClassification", + "Wav2Vec2ForXVector", + "Wav2Vec2Model", + "Wav2Vec2PreTrainedModel", +] diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2/processing_wav2vec2.py b/docs/transformers/build/lib/transformers/models/wav2vec2/processing_wav2vec2.py new file mode 100644 index 0000000000000000000000000000000000000000..077f5617198b9af4c93df88480938c5928834b17 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/wav2vec2/processing_wav2vec2.py @@ -0,0 +1,186 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""
+Speech processor class for Wav2Vec2
+"""
+
+import warnings
+from contextlib import contextmanager
+from typing import List, Optional, Union
+
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
+from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput
+from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
+from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
+
+
+class Wav2Vec2ProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {}
+
+
+class Wav2Vec2Processor(ProcessorMixin):
+    r"""
+    Constructs a Wav2Vec2 processor which wraps a Wav2Vec2 feature extractor and a Wav2Vec2 CTC tokenizer into a
+    single processor.
+
+    [`Wav2Vec2Processor`] offers all the functionalities of [`Wav2Vec2FeatureExtractor`] and [`PreTrainedTokenizer`].
+    See the docstring of [`~Wav2Vec2Processor.__call__`] and [`~Wav2Vec2Processor.decode`] for more information.
+
+    Args:
+        feature_extractor (`Wav2Vec2FeatureExtractor`):
+            An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is a required input.
+        tokenizer ([`PreTrainedTokenizer`]):
+            An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
+    """
+
+    feature_extractor_class = "Wav2Vec2FeatureExtractor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self, feature_extractor, tokenizer):
+        super().__init__(feature_extractor, tokenizer)
+        self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        try:
+            return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+        except (OSError, ValueError):
+            warnings.warn(
+                f"Loading a tokenizer inside {cls.__name__} from a config that does not"
+                " include a `tokenizer_class` attribute is deprecated and will be "
+                "removed in v5. Please add `'tokenizer_class': 'Wav2Vec2CTCTokenizer'`"
+                " attribute to either your `config.json` or `tokenizer_config.json` "
+                "file to suppress this warning.",
+                FutureWarning,
+            )
+
+            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+            return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+    def __call__(
+        self,
+        audio: AudioInput = None,
+        text: Optional[Union[str, List[str], TextInput, PreTokenizedInput]] = None,
+        images=None,
+        videos=None,
+        **kwargs: Unpack[Wav2Vec2ProcessorKwargs],
+    ):
+        """
+        When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
+        [`~Wav2Vec2FeatureExtractor.__call__`] and returns its output. If used in the context
+        [`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
+        [`~PreTrainedTokenizer.__call__`]. Please refer to the docstring of the above two methods for more
+        information.
+        """
+
+        if "raw_speech" in kwargs:
+            warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
+            audio = kwargs.pop("raw_speech")
+
+        if audio is None and text is None:
+            raise ValueError("You need to specify either an `audio` or `text` input to process.")
+
+        output_kwargs = self._merge_kwargs(
+            Wav2Vec2ProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor(
+                audio,
+                **output_kwargs["audio_kwargs"],
+                **output_kwargs["text_kwargs"],
+                **output_kwargs["common_kwargs"],
+            )
+
+        if audio is not None:
+            inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
+        if text is not None:
+            encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+        if text is None:
+            return inputs
+        elif audio is None:
+            return encodings
+        else:
+            inputs["labels"] = encodings["input_ids"]
+            return inputs
+
+    def pad(self, *args, **kwargs):
+        """
+        When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
+        [`~Wav2Vec2FeatureExtractor.pad`] and returns its output. If used in the context
+        [`~Wav2Vec2Processor.as_target_processor`] this method forwards all its arguments to PreTrainedTokenizer's
+        [`~PreTrainedTokenizer.pad`]. Please refer to the docstring of the above two methods for more information.
+        """
+        # For backward compatibility
+        if self._in_target_context_manager:
+            return self.current_processor.pad(*args, **kwargs)
+
+        input_features = kwargs.pop("input_features", None)
+        labels = kwargs.pop("labels", None)
+        if len(args) > 0:
+            input_features = args[0]
+            args = args[1:]
+
+        if input_features is not None:
+            input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
+        if labels is not None:
+            labels = self.tokenizer.pad(labels, **kwargs)
+
+        if labels is None:
+            return input_features
+        elif input_features is None:
+            return labels
+        else:
+            input_features["labels"] = labels["input_ids"]
+            return input_features
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
+        to the docstring of this method for more information.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    @contextmanager
+    def as_target_processor(self):
+        """
+        Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
+        Wav2Vec2.
+        """
+        warnings.warn(
+            "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
+            "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
+            "your audio inputs, or in a separate call)."
+        )
+        self._in_target_context_manager = True
+        self.current_processor = self.tokenizer
+        yield
+        self.current_processor = self.feature_extractor
+        self._in_target_context_manager = False
+
+
+__all__ = ["Wav2Vec2Processor"]
diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2/tokenization_wav2vec2.py b/docs/transformers/build/lib/transformers/models/wav2vec2/tokenization_wav2vec2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad51a4e4d028ff0c44ea630b022962f86e933d57
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/wav2vec2/tokenization_wav2vec2.py
@@ -0,0 +1,924 @@
+# coding=utf-8
+# Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization class for Wav2Vec2."""
+
+import json
+import os
+import warnings
+from dataclasses import dataclass
+from itertools import groupby
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ...tokenization_utils import PreTrainedTokenizer
+from ...tokenization_utils_base import AddedToken, BatchEncoding
+from ...utils import (
+    ModelOutput,
+    PaddingStrategy,
+    TensorType,
+    add_end_docstrings,
+    is_flax_available,
+    is_tf_available,
+    is_torch_available,
+    logging,
+    to_py_obj,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+if TYPE_CHECKING:
+    if is_torch_available():
+        import torch
+    if is_tf_available():
+        import tensorflow as tf
+    if is_flax_available():
+        import jax.numpy as jnp  # noqa: F401
+
+
+VOCAB_FILES_NAMES = {
+    "vocab_file": "vocab.json",
+    "tokenizer_config_file": "tokenizer_config.json",
+}
+
+
+# Wav2Vec2 has no max input length
+
+WAV2VEC2_KWARGS_DOCSTRING = r"""
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+                Activates and controls padding. Accepts the following values:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+                  different lengths).
+            max_length (`int`, *optional*):
+                Controls the maximum length to use by one of the truncation/padding parameters.
+
+                If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+                is required by one of the truncation/padding parameters. If the model has no specific maximum input
+                length (like XLNet) truncation/padding to a maximum length will be deactivated.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+                the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+            verbose (`bool`, *optional*, defaults to `True`):
+                Whether or not to print more information and warnings.
+"""
+
+ListOfDict = List[Dict[str, Union[int, str]]]
+
+
+@dataclass
+class Wav2Vec2CTCTokenizerOutput(ModelOutput):
+    """
+    Output type of [`Wav2Vec2CTCTokenizer`], with transcription.
+
+    Args:
+        text (list of `str` or `str`):
+            Decoded logits in text form. Usually the speech transcription.
+        char_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
+            Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
+            offsets can be used to compute time stamps for each character.
+        word_offsets (list of `List[Dict[str, Union[int, str]]]` or `List[Dict[str, Union[int, str]]]`):
+            Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
+            can be used to compute time stamps for each word.
+    """
+
+    text: Union[List[str], str]
+    char_offsets: Union[List[ListOfDict], ListOfDict] = None
+    word_offsets: Union[List[ListOfDict], ListOfDict] = None
+
+
+class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
+    """
+    Constructs a Wav2Vec2CTC tokenizer.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer
+    to the superclass for more information regarding such methods.
+
+    Args:
+        vocab_file (`str`):
+            File containing the vocabulary.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sentence token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sentence token.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
+            this token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        word_delimiter_token (`str`, *optional*, defaults to `"|"`):
+            The token used for defining the end of a word.
+        do_lower_case (`bool`, *optional*, defaults to `False`):
+            Whether or not to accept lowercase input and lowercase the output when decoding.
+        target_lang (`str`, *optional*):
+            A target language the tokenizer should set by default. `target_lang` has to be defined for multi-lingual,
+            nested vocabulary such as [facebook/mms-1b-all](https://huggingface.co/facebook/mms-1b-all).
+
+        **kwargs
+            Additional keyword arguments passed along to [`PreTrainedTokenizer`]
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    model_input_names = ["input_ids", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        word_delimiter_token="|",
+        replace_word_delimiter_char=" ",
+        do_lower_case=False,
+        target_lang=None,
+        **kwargs,
+    ):
+        self._word_delimiter_token = word_delimiter_token
+
+        self.do_lower_case = do_lower_case
+        self.replace_word_delimiter_char = replace_word_delimiter_char
+        self.target_lang = target_lang
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.vocab = json.load(vocab_handle)
+
+        # if target lang is defined vocab must be a nested dict
+        # with each target lang being one vocabulary
+        if target_lang is not None:
+            self.encoder = self.vocab[target_lang]
+        else:
+            self.encoder = self.vocab
+
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+        super().__init__(
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            do_lower_case=do_lower_case,
+            word_delimiter_token=word_delimiter_token,
+            replace_word_delimiter_char=replace_word_delimiter_char,
+            target_lang=target_lang,
+            **kwargs,
+        )
+
+        # make sure that tokens made of several
+        # characters are not split at tokenization
+        for token in self.encoder.keys():
+            if len(token) > 1:
+                self.add_tokens(AddedToken(token, rstrip=True, lstrip=True, normalized=False))
+
+    def set_target_lang(self, target_lang: str):
+        """
+        Set the target language of a nested multi-lingual dictionary
+        """
+        if self.vocab == self.encoder:
+            raise ValueError(f"{self.vocab} is not a multi-lingual, nested tokenizer. Cannot set target language.")
+
+        if target_lang not in self.vocab:
+            raise ValueError(f"{target_lang} does not exist. Choose one of {', '.join(self.vocab.keys())}.")
+
+        self.target_lang = target_lang
+        self.init_kwargs["target_lang"] = target_lang
+        self.encoder = self.vocab[target_lang]
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+        # make sure that tokens made of several
+        # characters are not split at tokenization
+        for token in self.encoder.keys():
+            if len(token) > 1:
+                self.add_tokens(AddedToken(token, rstrip=True, lstrip=True, normalized=False))
+
+    @property
+    def word_delimiter_token(self) -> str:
+        """
+        `str`: Word delimiter token. Log an error if used while not having been set.
+        """
+        if self._word_delimiter_token is None and self.verbose:
+            logger.error("Using word_delimiter_token, but it is not set yet.")
+            return None
+        return str(self._word_delimiter_token)
+
+    @property
+    def word_delimiter_token_id(self) -> Optional[int]:
+        """
+        `Optional[int]`: Id of the word_delimiter_token in the vocabulary. Returns `None` if the token has not been
+        set.
+        """
+        if self._word_delimiter_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.word_delimiter_token)
+
+    @word_delimiter_token.setter
+    def word_delimiter_token(self, value):
+        self._word_delimiter_token = value
+
+    @word_delimiter_token_id.setter
+    def word_delimiter_token_id(self, value):
+        self._word_delimiter_token = self.convert_tokens_to_ids(value)
+
+    @property
+    def vocab_size(self) -> int:
+        return len(self.decoder)
+
+    def get_vocab(self) -> Dict:
+        vocab = dict(self.encoder)
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
+        # Overwritten to never strip!
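+        # Stripping would silently alter added tokens (e.g. multi-character vocab entries) when they appear next to
+        # whitespace, so every plain-string token is wrapped with `rstrip=False, lstrip=False` here.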
+        to_add = []
+        for token in new_tokens:
+            if isinstance(token, str):
+                to_add.append(AddedToken(token, rstrip=False, lstrip=False, normalized=False))
+            else:
+                to_add.append(token)
+
+        return super()._add_tokens(to_add, special_tokens)
+
+    def _tokenize(self, text, **kwargs):
+        """
+        Converts a string into a sequence of tokens (string), using the tokenizer.
+        """
+        if self.do_lower_case:
+            # the vocabulary is uppercase, so fold the input casing to match;
+            # the output is lowercased again when decoding
+            text = text.upper()
+
+        return list(text.replace(" ", self.word_delimiter_token))
+
+    def _convert_token_to_id(self, token: str) -> int:
+        """Converts a token (str) in an index (integer) using the vocab."""
+        return self.encoder.get(token, self.encoder.get(self.unk_token))
+
+    def _convert_id_to_token(self, index: int) -> str:
+        """Converts an index (integer) in a token (str) using the vocab."""
+        result = self.decoder.get(index, self.unk_token)
+        return result
+
+    def convert_tokens_to_string(
+        self,
+        tokens: List[str],
+        group_tokens: bool = True,
+        spaces_between_special_tokens: bool = False,
+        output_char_offsets: bool = False,
+        output_word_offsets: bool = False,
+    ) -> Dict[str, Union[str, float]]:
+        """
+        Converts connectionist-temporal-classification (CTC) output tokens into a single string.
+        """
+        if len(tokens) == 0:
+            return {"text": "", "char_offsets": [], "word_offsets": []}
+        # group same tokens into non-repeating tokens in CTC style decoding
+        if group_tokens:
+            chars, char_repetitions = zip(*((token, len(list(group_iter))) for token, group_iter in groupby(tokens)))
+        else:
+            chars = tokens
+            char_repetitions = len(tokens) * [1]
+
+        # filter self.pad_token which is used as CTC-blank token
+        processed_chars = list(filter(lambda char: char != self.pad_token, chars))
+
+        # replace delimiter token
+        processed_chars = [
+            self.replace_word_delimiter_char if char == self.word_delimiter_token else char for char in processed_chars
+        ]
+
+        # retrieve offsets
+        char_offsets = word_offsets = None
+        if output_char_offsets or output_word_offsets:
+            char_offsets = self._compute_offsets(char_repetitions, chars, self.pad_token)
+
+            if len(char_offsets) != len(processed_chars):
+                raise ValueError(
+                    f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}"
+                    " have to be of the same length, but are: "
+                    f"`len(offsets)`: {len(char_offsets)} and `len(processed_tokens)`:"
+                    f" {len(processed_chars)}"
+                )
+
+            # set tokens to correct processed token
+            for i, char in enumerate(processed_chars):
+                char_offsets[i]["char"] = char
+
+            # retrieve word offsets from character offsets
+            word_offsets = None
+            if output_word_offsets:
+                word_offsets = self._get_word_offsets(char_offsets, self.replace_word_delimiter_char)
+
+            # don't output chars if not set to True
+            if not output_char_offsets:
+                char_offsets = None
+
+        # join to string
+        join_char = " " if spaces_between_special_tokens else ""
+        string = join_char.join(processed_chars).strip()
+
+        if self.do_lower_case:
+            string = string.lower()
+
+        return {"text": string, "char_offsets": char_offsets, "word_offsets": word_offsets}
+
+    @staticmethod
+    def _compute_offsets(
+        char_repetitions: List[int], chars: List[str], ctc_token: str
+    ) -> List[Dict[str, Union[str, int]]]:
+        end_indices = np.asarray(char_repetitions).cumsum()
+        start_indices = np.concatenate(([0], end_indices[:-1]))
+
+        offsets = [
+            {"char": t, "start_offset": s, "end_offset": e} for t, s, e in zip(chars, start_indices, end_indices)
+        ]
+
+        # filter out CTC token
+        offsets = list(filter(lambda offset: offset["char"] != ctc_token, offsets))
+        return offsets
+
+    @staticmethod
+    def _get_word_offsets(
+        offsets: List[Dict[str, Union[str, int]]], word_delimiter_char: str = " "
+    ) -> List[Dict[str, Union[str, int]]]:
+        word_offsets = []
+
+        last_state = "SPACE"
+        word = ""
+        start_offset = 0
+        end_offset = 0
+        for offset in offsets:
+            char = offset["char"]
+            state = "SPACE" if char == word_delimiter_char else "WORD"
+
+            if state == last_state:
+                # If we are in the same state as before, we simply repeat what we've done before
+                end_offset = offset["end_offset"]
+                word += char
+            else:
+                # Switching state
+                if state == "SPACE":
+                    # Finishing a word
+                    word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
+                else:
+                    # Starting a new word
+                    start_offset = offset["start_offset"]
+                    end_offset = offset["end_offset"]
+                    word = char
+
+            last_state = state
+        if last_state == "WORD":
+            word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
+
+        return word_offsets
+
+    def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
+        if is_split_into_words:
+            text = " " + text
+        return (text, kwargs)
+
+    def _decode(
+        self,
+        token_ids: List[int],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        group_tokens: bool = True,
+        spaces_between_special_tokens: bool = False,
+        output_word_offsets: Optional[bool] = False,
+        output_char_offsets: Optional[bool] = False,
+    ) -> str:
+        """
+        A special `_decode` function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly
+        the same as tokens of the base vocabulary and therefore the function `convert_tokens_to_string` has to be
+        called on the whole token list and not individually on added tokens.
+        """
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+
+        result = []
+        for token in filtered_tokens:
+            if skip_special_tokens and (
+                token in self.all_special_ids or (token != self.pad_token and token in self.all_special_tokens)
+            ):
+                continue
+            result.append(token)
+
+        string_output = self.convert_tokens_to_string(
+            result,
+            group_tokens=group_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+            output_word_offsets=output_word_offsets,
+            output_char_offsets=output_char_offsets,
+        )
+
+        text = string_output["text"]
+
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
+        if clean_up_tokenization_spaces:
+            text = self.clean_up_tokenization(text)
+
+        if output_word_offsets or output_char_offsets:
+            return Wav2Vec2CTCTokenizerOutput(
+                text=text,
+                char_offsets=string_output["char_offsets"],
+                word_offsets=string_output["word_offsets"],
+            )
+        else:
+            return text
+
+    # overwritten from `tokenization_utils_base.py` because tokenizer can output
+    # `ModelOutput` which should not be a list for batched output and
+    # because we need docs for `output_char_offsets` here
+    def batch_decode(
+        self,
+        sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        output_char_offsets: bool = False,
+        output_word_offsets: bool = False,
+        **kwargs,
+    ) -> List[str]:
+        """
+        Convert a list of lists of token ids into a list of strings by calling decode.
+
+        Args:
+            sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
+                List of tokenized input ids.
Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether or not to clean up the tokenization spaces. + output_char_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output character offsets. Character offsets can be used in combination with the + sampling rate and model downsampling rate to compute the time-stamps of transcribed characters. + + + + Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make + use of `output_char_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched + output. + + + + output_word_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate + and model downsampling rate to compute the time-stamps of transcribed words. + + + + Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make + use of `output_word_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched + output. + + + + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `List[str]` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The list of decoded + sentences. Will be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when + `output_char_offsets == True` or `output_word_offsets == True`. + """ + batch_decoded = [ + self.decode( + seq, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + output_char_offsets=output_char_offsets, + output_word_offsets=output_word_offsets, + **kwargs, + ) + for seq in sequences + ] + if output_char_offsets or output_word_offsets: + # transform list of dicts to dict of lists + return Wav2Vec2CTCTokenizerOutput({k: [d[k] for d in batch_decoded] for k in batch_decoded[0]}) + + return batch_decoded + + # overwritten from `tokenization_utils_base.py` because we need docs for `output_char_offsets` + # and `output_word_offsets` here + def decode( + self, + token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + output_char_offsets: bool = False, + output_word_offsets: bool = False, + **kwargs, + ) -> str: + """ + Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special + tokens and clean up tokenization spaces. + + Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`. + + Args: + token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`): + List of tokenized input ids. Can be obtained using the `__call__` method. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + clean_up_tokenization_spaces (`bool`, *optional*): + Whether or not to clean up the tokenization spaces. + output_char_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output character offsets. Character offsets can be used in combination with the + sampling rate and model downsampling rate to compute the time-stamps of transcribed characters. 
+ + + + Please take a look at the example below to better understand how to make use of `output_char_offsets`. + + + + output_word_offsets (`bool`, *optional*, defaults to `False`): + Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate + and model downsampling rate to compute the time-stamps of transcribed words. + + + + Please take a look at the example below to better understand how to make use of `output_word_offsets`. + + + + kwargs (additional keyword arguments, *optional*): + Will be passed to the underlying model specific decode method. + + Returns: + `str` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The list of decoded + sentences. Will be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when + `output_char_offsets == True` or `output_word_offsets == True`. + + Example: + + ```python + >>> # Let's see how to retrieve time steps for a model + >>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC + >>> from datasets import load_dataset + >>> import datasets + >>> import torch + + >>> # import model, feature extractor, tokenizer + >>> model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h") + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h") + + >>> # load first sample of English common_voice + >>> dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", split="train", streaming=True, trust_remote_code=True) + >>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000)) + >>> dataset_iter = iter(dataset) + >>> sample = next(dataset_iter) + + >>> # forward sample through model to get greedily predicted transcription ids + >>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values + >>> logits = model(input_values).logits[0] + >>> pred_ids = torch.argmax(logits, axis=-1) + + >>> # retrieve word stamps (analogous commands for `output_char_offsets`) + >>> outputs = tokenizer.decode(pred_ids, output_word_offsets=True) + >>> # compute `time_offset` in seconds as product of downsampling ratio and sampling_rate + >>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate + + >>> word_offsets = [ + ... { + ... "word": d["word"], + ... "start_time": round(d["start_offset"] * time_offset, 2), + ... "end_time": round(d["end_offset"] * time_offset, 2), + ... } + ... for d in outputs.word_offsets + ... 
]
+        >>> # compare word offsets with audio `en_train_0/common_voice_en_19121553.mp3` online on the dataset viewer:
+        >>> # https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0/viewer/en
+        >>> word_offsets[:3]
+        [{'word': 'THE', 'start_time': 0.7, 'end_time': 0.78}, {'word': 'TRICK', 'start_time': 0.88, 'end_time': 1.08}, {'word': 'APPEARS', 'start_time': 1.2, 'end_time': 1.64}]
+        ```"""
+        # Convert inputs to python lists
+        token_ids = to_py_obj(token_ids)
+
+        return self._decode(
+            token_ids=token_ids,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            output_char_offsets=output_char_offsets,
+            output_word_offsets=output_word_offsets,
+            **kwargs,
+        )
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        with open(vocab_file, "w", encoding="utf-8") as f:
+            f.write(json.dumps(self.vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
+
+        return (vocab_file,)
+
+
+class Wav2Vec2Tokenizer(PreTrainedTokenizer):
+    """
+    Constructs a Wav2Vec2 tokenizer.
+
+    This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer
+    to the superclass for more information regarding such methods.
+
+    Args:
+        vocab_file (`str`):
+            File containing the vocabulary.
+        bos_token (`str`, *optional*, defaults to `"<s>"`):
+            The beginning of sentence token.
+        eos_token (`str`, *optional*, defaults to `"</s>"`):
+            The end of sentence token.
+        unk_token (`str`, *optional*, defaults to `"<unk>"`):
+            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+            token instead.
+        pad_token (`str`, *optional*, defaults to `"<pad>"`):
+            The token used for padding, for example when batching sequences of different lengths.
+        word_delimiter_token (`str`, *optional*, defaults to `"|"`):
+            The token used for defining the end of a word.
+        do_lower_case (`bool`, *optional*, defaults to `False`):
+            Whether or not to lowercase the output when decoding.
+        do_normalize (`bool`, *optional*, defaults to `False`):
+            Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
+            improve the performance for some models, *e.g.*,
+            [wav2vec2-lv60](https://huggingface.co/models?search=lv60).
+        return_attention_mask (`bool`, *optional*, defaults to `False`):
+            Whether or not [`~Wav2Vec2Tokenizer.__call__`] should return `attention_mask`.
+
+
+
+            Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as
+            [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using
+            `attention_mask`. For such models, `input_values` should simply be padded with 0 and no `attention_mask`
+            should be passed.
+
+            For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as
+            [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should be
+            passed for batched inference.
+
+
+
+        **kwargs
+            Additional keyword arguments passed along to [`PreTrainedTokenizer`]
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    pretrained_vocab_files_map = {
+        "vocab_file": {
+            "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
+        },
+        "tokenizer_config_file": {
+            "facebook/wav2vec2-base-960h": (
+                "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json"
+            ),
+        },
+    }
+    model_input_names = ["input_values", "attention_mask"]
+
+    def __init__(
+        self,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        word_delimiter_token="|",
+        do_lower_case=False,
+        do_normalize=False,
+        return_attention_mask=False,
+        **kwargs,
+    ):
+        warnings.warn(
+            "The class `Wav2Vec2Tokenizer` is deprecated and will be removed in version 5 of Transformers. Please use"
+            " `Wav2Vec2Processor` or `Wav2Vec2CTCTokenizer` instead.",
+            FutureWarning,
+        )
+
+        self._word_delimiter_token = word_delimiter_token
+
+        self.do_lower_case = do_lower_case
+        self.return_attention_mask = return_attention_mask
+        self.do_normalize = do_normalize
+
+        with open(vocab_file, encoding="utf-8") as vocab_handle:
+            self.encoder = json.load(vocab_handle)
+
+        self.decoder = {v: k for k, v in self.encoder.items()}
+
+        super().__init__(
+            unk_token=unk_token,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            pad_token=pad_token,
+            do_lower_case=do_lower_case,
+            do_normalize=do_normalize,
+            return_attention_mask=return_attention_mask,
+            word_delimiter_token=word_delimiter_token,
+            **kwargs,
+        )
+
+    @property
+    def word_delimiter_token(self) -> str:
+        """
+        `str`: Word delimiter token. Logs an error if used while not having been set.
+        """
+        if self._word_delimiter_token is None and self.verbose:
+            logger.error("Using word_delimiter_token, but it is not set yet.")
+            return None
+        return str(self._word_delimiter_token)
+
+    @property
+    def word_delimiter_token_id(self) -> Optional[int]:
+        """
+        `Optional[int]`: Id of the word_delimiter_token in the vocabulary. Returns `None` if the token has not been
+        set.
+        """
+        if self._word_delimiter_token is None:
+            return None
+        return self.convert_tokens_to_ids(self.word_delimiter_token)
+
+    @word_delimiter_token.setter
+    def word_delimiter_token(self, value):
+        self._word_delimiter_token = value
+
+    @word_delimiter_token_id.setter
+    def word_delimiter_token_id(self, value):
+        self._word_delimiter_token = self.convert_tokens_to_ids(value)
+
+    @add_end_docstrings(WAV2VEC2_KWARGS_DOCSTRING)
+    def __call__(
+        self,
+        raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
+        padding: Union[bool, str, PaddingStrategy] = False,
+        max_length: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+        sequences.
+
+        Args:
+            raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
+                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
+                values, a list of numpy arrays or a list of lists of float values. Must be mono channel audio, not
+                stereo, i.e. a single float per timestep.
+
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+ Default value is picked from the class attribute of the same name. + """ + + is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 + if is_batched_numpy and len(raw_speech.shape) > 2: + raise ValueError(f"Only mono-channel audio is supported for input to {self}") + is_batched = is_batched_numpy or ( + isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) + ) + + # make sure input is in list format + if is_batched and not isinstance(raw_speech[0], np.ndarray): + raw_speech = [np.asarray(speech) for speech in raw_speech] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech) + + # always return batch + if not is_batched: + raw_speech = [raw_speech] + + # zero-mean and unit-variance normalization + if self.do_normalize: + raw_speech = [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in raw_speech] + + # convert into correct format for padding + encoded_inputs = BatchEncoding({"input_values": raw_speech}) + + padded_inputs = self.pad( + encoded_inputs, + padding=padding, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + padding_side=padding_side, + return_attention_mask=self.return_attention_mask, + return_tensors=return_tensors, + verbose=verbose, + ) + + return padded_inputs + + @property + def vocab_size(self) -> int: + return len(self.decoder) + + def get_vocab(self) -> Dict: + return dict(self.encoder, **self.added_tokens_encoder) + + def _convert_token_to_id(self, token: str) -> int: + """Converts a token (str) in an index (integer) using the vocab.""" + return self.encoder.get(token, self.encoder.get(self.unk_token)) + + def _convert_id_to_token(self, index: int) -> str: + """Converts an index (integer) in a token (str) using the vocab.""" + result = self.decoder.get(index, self.unk_token) + return result + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + """ + Converts a connectionist-temporal-classification (CTC) output tokens into a single string. 
+ """ + # group same tokens into non-repeating tokens in CTC style decoding + grouped_tokens = [token_group[0] for token_group in groupby(tokens)] + + # filter self.pad_token which is used as CTC-blank token + filtered_tokens = list(filter(lambda token: token != self.pad_token, grouped_tokens)) + + # replace delimiter token + string = "".join([" " if token == self.word_delimiter_token else token for token in filtered_tokens]).strip() + + if self.do_lower_case: + string = string.lower() + + return string + + def _decode( + self, + token_ids: List[int], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + **kwargs, + ) -> str: + """ + special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the + same as tokens of the base vocabulary and therefore the function `convert_tokens_to_string` has to be called on + the whole token list and not individually on added tokens + """ + filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens) + + result = [] + for token in filtered_tokens: + if skip_special_tokens and ( + token in self.all_special_ids or (token != self.pad_token and token in self.all_special_tokens) + ): + continue + result.append(token) + + text = self.convert_tokens_to_string(result) + + clean_up_tokenization_spaces = ( + clean_up_tokenization_spaces + if clean_up_tokenization_spaces is not None + else self.clean_up_tokenization_spaces + ) + if clean_up_tokenization_spaces: + clean_text = self.clean_up_tokenization(text) + return clean_text + else: + return text + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + with open(vocab_file, "w", encoding="utf-8") as f: + f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") + + return (vocab_file,) + + +__all__ = ["Wav2Vec2CTCTokenizer", "Wav2Vec2Tokenizer"] diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2_bert/__init__.py b/docs/transformers/build/lib/transformers/models/wav2vec2_bert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7520263c51bcd3b63747ee0a55e1baff24a777f9 --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/wav2vec2_bert/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_wav2vec2_bert import *
+    from .modeling_wav2vec2_bert import *
+    from .processing_wav2vec2_bert import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py b/docs/transformers/build/lib/transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..db52cc5baed3675957c48abe18d50dd59faf4222
--- /dev/null
+++ b/docs/transformers/build/lib/transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py
@@ -0,0 +1,313 @@
+# coding=utf-8
+# Copyright 2024 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Wav2Vec2Bert model configuration"""
+
+from ...configuration_utils import PretrainedConfig
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class Wav2Vec2BertConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Wav2Vec2BertModel`]. It is used to
+    instantiate a Wav2Vec2Bert model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2Bert
+    [facebook/wav2vec2-bert-rel-pos-large](https://huggingface.co/facebook/wav2vec2-bert-rel-pos-large)
+    architecture.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*):
+            Vocabulary size of the Wav2Vec2Bert model. Defines the number of different tokens that can be
+            represented by the `inputs_ids` passed when calling [`Wav2Vec2BertModel`].
+        hidden_size (`int`, *optional*, defaults to 1024):
+            Dimensionality of the encoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 24):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        intermediate_size (`int`, *optional*, defaults to 4096):
+            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+        feature_projection_input_dim (`int`, *optional*, defaults to 160):
+            Input dimension of this model, i.e. the dimension after processing input audio with
+            [`SeamlessM4TFeatureExtractor`] or [`Wav2Vec2BertProcessor`].
+        hidden_act (`str` or `function`, *optional*, defaults to `"swish"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
+        hidden_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+        activation_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for activations inside the fully connected layer.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        feat_proj_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability for the feature projection.
+        final_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for the final projection layer of [`Wav2Vec2BertForCTC`].
+        layerdrop (`float`, *optional*, defaults to 0.1):
+            The LayerDrop probability. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556) for more
+            details.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the layer normalization layers.
+        apply_spec_augment (`bool`, *optional*, defaults to `True`):
+            Whether to apply *SpecAugment* data augmentation to the outputs of the feature encoder. For reference see
+            [SpecAugment: A Simple Data Augmentation Method for Automatic Speech
+            Recognition](https://arxiv.org/abs/1904.08779).
+        mask_time_prob (`float`, *optional*, defaults to 0.05):
+            Percentage (between 0 and 1) of all feature vectors along the time axis which will be masked. The masking
+            procedure generates `mask_time_prob*len(time_axis)/mask_time_length` independent masks over the axis. If
+            reasoning from the probability of each feature vector to be chosen as the start of the vector span to be
+            masked, *mask_time_prob* should be `prob_vector_start*mask_time_length`. Note that overlap may decrease the
+            actual percentage of masked vectors. This is only relevant if `apply_spec_augment is True`.
+        mask_time_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the time axis.
+        mask_time_min_masks (`int`, *optional*, defaults to 2):
+            The minimum number of masks of length `mask_time_length` generated along the time axis, each time step,
+            irrespective of `mask_time_prob`. Only relevant if `mask_time_prob*len(time_axis)/mask_time_length <
+            mask_time_min_masks`.
+        mask_feature_prob (`float`, *optional*, defaults to 0.0):
+            Percentage (between 0 and 1) of all feature vectors along the feature axis which will be masked. The
+            masking procedure generates `mask_feature_prob*len(feature_axis)/mask_feature_length` independent masks
+            over the axis. If reasoning from the probability of each feature vector to be chosen as the start of the
+            vector span to be masked, *mask_feature_prob* should be `prob_vector_start*mask_feature_length`. Note that
+            overlap may decrease the actual percentage of masked vectors. This is only relevant if `apply_spec_augment
+            is True`.
+        mask_feature_length (`int`, *optional*, defaults to 10):
+            Length of vector span along the feature axis.
+        mask_feature_min_masks (`int`, *optional*, defaults to 0):
+            The minimum number of masks of length `mask_feature_length` generated along the feature axis, each time
+            step, irrespective of `mask_feature_prob`. Only relevant if
+            `mask_feature_prob*len(feature_axis)/mask_feature_length < mask_feature_min_masks`.
+        ctc_loss_reduction (`str`, *optional*, defaults to `"sum"`):
+            Specifies the reduction to apply to the output of `torch.nn.CTCLoss`. Only relevant when training an
+            instance of [`Wav2Vec2BertForCTC`].
+        ctc_zero_infinity (`bool`, *optional*, defaults to `False`):
+            Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly
+            occur when the inputs are too short to be aligned to the targets. Only relevant when training an instance
+            of [`Wav2Vec2BertForCTC`].
+        use_weighted_layer_sum (`bool`, *optional*, defaults to `False`):
+            Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+            instance of [`Wav2Vec2BertForSequenceClassification`].
+        classifier_proj_size (`int`, *optional*, defaults to 768):
+            Dimensionality of the projection before token mean-pooling for classification.
+        tdnn_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 1500)`):
+            A tuple of integers defining the number of output channels of each 1D convolutional layer in the *TDNN*
+            module of the *XVector* model. The length of *tdnn_dim* defines the number of *TDNN* layers.
+        tdnn_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 3, 3, 1, 1)`):
+            A tuple of integers defining the kernel size of each 1D convolutional layer in the *TDNN* module of the
+            *XVector* model. The length of *tdnn_kernel* has to match the length of *tdnn_dim*.
+        tdnn_dilation (`Tuple[int]` or `List[int]`, *optional*, defaults to `(1, 2, 3, 1, 1)`):
+            A tuple of integers defining the dilation factor of each 1D convolutional layer in the *TDNN* module of
+            the *XVector* model. The length of *tdnn_dilation* has to match the length of *tdnn_dim*.
+        xvector_output_dim (`int`, *optional*, defaults to 512):
+            Dimensionality of the *XVector* embedding vectors.
+        pad_token_id (`int`, *optional*, defaults to 0): The id of the _padding_ token.
+        bos_token_id (`int`, *optional*, defaults to 1): The id of the _beginning-of-stream_ token.
+        eos_token_id (`int`, *optional*, defaults to 2): The id of the _end-of-stream_ token.
+        add_adapter (`bool`, *optional*, defaults to `False`):
+            Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very
+            useful for warm-starting Wav2Vec2Bert for SpeechEncoderDecoder models.
+        adapter_kernel_size (`int`, *optional*, defaults to 3):
+            Kernel size of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+        adapter_stride (`int`, *optional*, defaults to 2):
+            Stride of the convolutional layers in the adapter network. Only relevant if `add_adapter is True`.
+        num_adapter_layers (`int`, *optional*, defaults to 1):
+            Number of convolutional layers that should be used in the adapter network. Only relevant if `add_adapter is
+            True`.
+        adapter_act (`str` or `function`, *optional*, defaults to `"relu"`):
+            The non-linear activation function (function or string) in the adapter layers. If string, `"gelu"`,
+            `"relu"`, `"selu"`, `"swish"` and `"gelu_new"` are supported.
+        use_intermediate_ffn_before_adapter (`bool`, *optional*, defaults to `False`):
+            Whether an intermediate feed-forward block should be stacked on top of the Wav2Vec2Bert Encoder and before
+            the adapter network. Only relevant if `add_adapter is True`.
+        output_hidden_size (`int`, *optional*):
+            Dimensionality of the encoder output layer.
+            If not defined, this defaults to `hidden_size`. Only relevant if `add_adapter is True`.
+        position_embeddings_type (`str`, *optional*, defaults to `"relative_key"`):
+            Can be set to:
+                - `rotary`, for rotary position embeddings.
+                - `relative`, for relative position embeddings.
+                - `relative_key`, for relative position embeddings as defined by Shaw in [Self-Attention
+                  with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+            If left as `None`, no relative position embeddings are applied.
+        rotary_embedding_base (`int`, *optional*, defaults to 10000):
+            If `"rotary"` position embeddings are used, defines the size of the embedding base.
+        max_source_positions (`int`, *optional*, defaults to 5000):
+            If `"relative"` position embeddings are used, defines the maximum source input positions.
+        left_max_position_embeddings (`int`, *optional*, defaults to 64):
+            If `"relative_key"` (aka Shaw) position embeddings are used, defines the left clipping value for relative positions.
+        right_max_position_embeddings (`int`, *optional*, defaults to 8):
+            If `"relative_key"` (aka Shaw) position embeddings are used, defines the right clipping value for relative positions.
+        conv_depthwise_kernel_size (`int`, *optional*, defaults to 31):
+            Kernel size of convolutional depthwise 1D layer in Conformer blocks.
+        conformer_conv_dropout (`float`, *optional*, defaults to 0.1):
+            The dropout probability for all convolutional layers in Conformer blocks.
+
+    Example:
+
+    ```python
+    >>> from transformers import Wav2Vec2BertConfig, Wav2Vec2BertModel
+
+    >>> # Initializing a Wav2Vec2Bert facebook/wav2vec2-bert-rel-pos-large style configuration
+    >>> configuration = Wav2Vec2BertConfig()
+
+    >>> # Initializing a model (with random weights) from the facebook/wav2vec2-bert-rel-pos-large style configuration
+    >>> model = Wav2Vec2BertModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "wav2vec2-bert"
+
+    def __init__(
+        self,
+        vocab_size=None,
+        hidden_size=1024,
+        num_hidden_layers=24,
+        num_attention_heads=16,
+        intermediate_size=4096,
+        feature_projection_input_dim=160,
+        hidden_act="swish",
+        hidden_dropout=0.0,
+        activation_dropout=0.0,
+        attention_dropout=0.0,
+        feat_proj_dropout=0.0,
+        final_dropout=0.1,
+        layerdrop=0.1,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        apply_spec_augment=True,
+        mask_time_prob=0.05,
+        mask_time_length=10,
+        mask_time_min_masks=2,
+        mask_feature_prob=0.0,
+        mask_feature_length=10,
+        mask_feature_min_masks=0,
+        ctc_loss_reduction="sum",
+        ctc_zero_infinity=False,
+        use_weighted_layer_sum=False,
+        classifier_proj_size=768,
+        tdnn_dim=(512, 512, 512, 512, 1500),
+        tdnn_kernel=(5, 3, 3, 1, 1),
+        tdnn_dilation=(1, 2, 3, 1, 1),
+        xvector_output_dim=512,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        add_adapter=False,
+        adapter_kernel_size=3,
+        adapter_stride=2,
+        num_adapter_layers=1,
+        adapter_act="relu",
+        use_intermediate_ffn_before_adapter=False,
+        output_hidden_size=None,
+        position_embeddings_type="relative_key",
+        rotary_embedding_base=10000,
+        max_source_positions=5000,
+        left_max_position_embeddings=64,
+        right_max_position_embeddings=8,
+        conv_depthwise_kernel_size=31,
+        conformer_conv_dropout=0.1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.intermediate_size = intermediate_size
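+        # Illustrative note, not part of the original code: with the SpecAugment
+        # defaults documented above (mask_time_prob=0.05, mask_time_length=10), a
+        # 400-frame utterance draws about 0.05 * 400 / 10 = 2 independent time
+        # masks, and mask_time_min_masks=2 guarantees at least that many whenever
+        # masking is applied during training.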
self.hidden_act = hidden_act + self.num_attention_heads = num_attention_heads + self.feature_projection_input_dim = feature_projection_input_dim + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.activation_dropout = activation_dropout + self.feat_proj_dropout = feat_proj_dropout + self.final_dropout = final_dropout + self.layerdrop = layerdrop + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.vocab_size = vocab_size + self.use_weighted_layer_sum = use_weighted_layer_sum + self.max_source_positions = max_source_positions + + if position_embeddings_type is not None and position_embeddings_type not in [ + "rotary", + "relative", + "relative_key", + ]: + raise ValueError( + """ + `position_embeddings_type` is not valid. It must be one of the following values: + `["rotary", "relative", "relative_key"]` or left as `None`. + """ + ) + self.position_embeddings_type = position_embeddings_type + self.rotary_embedding_base = rotary_embedding_base + self.left_max_position_embeddings = left_max_position_embeddings + self.right_max_position_embeddings = right_max_position_embeddings + + # Conformer-block related + self.conv_depthwise_kernel_size = conv_depthwise_kernel_size + self.conformer_conv_dropout = conformer_conv_dropout + + # fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779 + self.apply_spec_augment = apply_spec_augment + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length + self.mask_time_min_masks = mask_time_min_masks + self.mask_feature_prob = mask_feature_prob + self.mask_feature_length = mask_feature_length + self.mask_feature_min_masks = mask_feature_min_masks + + # ctc loss + self.ctc_loss_reduction = ctc_loss_reduction + self.ctc_zero_infinity = ctc_zero_infinity + + # adapter + self.add_adapter = add_adapter + self.adapter_kernel_size = adapter_kernel_size + self.adapter_stride = adapter_stride + self.num_adapter_layers = num_adapter_layers + self.adapter_act = adapter_act + self.output_hidden_size = output_hidden_size if output_hidden_size is not None else hidden_size + if use_intermediate_ffn_before_adapter and not add_adapter: + raise ValueError("`use_intermediate_ffn_before_adapter` is `True` but `add_adapter` is `False`.") + self.use_intermediate_ffn_before_adapter = use_intermediate_ffn_before_adapter + + # SequenceClassification-specific parameter. Feel free to ignore for other classes. + self.classifier_proj_size = classifier_proj_size + + # XVector-specific parameters. Feel free to ignore for other classes. + self.tdnn_dim = list(tdnn_dim) + self.tdnn_kernel = list(tdnn_kernel) + self.tdnn_dilation = list(tdnn_dilation) + self.xvector_output_dim = xvector_output_dim + + @property + def inputs_to_logits_ratio(self): + ratio = self.feature_projection_input_dim * 2 + if self.add_adapter: + ratio = ratio * (self.adapter_stride**self.num_adapter_layers) + return ratio + + +__all__ = ["Wav2Vec2BertConfig"] diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py b/docs/transformers/build/lib/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..33510654dcca5ec24940f744653df1e909db3fae --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/wav2vec2_bert/convert_wav2vec2_seamless_checkpoint.py @@ -0,0 +1,217 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. 
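+# Illustrative note with assumed values, not from the original file: the
+# `inputs_to_logits_ratio` property defined above gives the number of raw input
+# samples per logit frame. Without an adapter,
+#     Wav2Vec2BertConfig().inputs_to_logits_ratio == 160 * 2 == 320
+# so at an assumed 16 kHz sampling rate each logit frame spans
+# 320 / 16000 = 0.02 s, which is the `time_offset` used when converting decoded
+# word offsets to time stamps.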
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Wav2Vec2Bert BERT checkpoint.""" + +import argparse + +import torch +import torchaudio +from fairseq2.data import Collater +from fairseq2.data.audio import WaveformToFbankConverter +from fairseq2.nn.padding import get_seqs_and_padding_mask +from seamless_communication.models.conformer_shaw import load_conformer_shaw_model + +from transformers import ( + SeamlessM4TFeatureExtractor, + Wav2Vec2BertConfig, + Wav2Vec2BertModel, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + + +wav2vec_convert_list = [ + ("encoder_frontend.model_dim_proj", "feature_projection.projection"), + ("encoder_frontend.post_extract_layer_norm", "feature_projection.layer_norm"), + ("encoder_frontend.pos_encoder.conv", "encoder.pos_conv_embed.conv"), + ("encoder.inner.layers", "encoder.layers"), + ("encoder.inner_layer_norm", "encoder.layer_norm"), + ("encoder.adaptor_layers", "adapter.layers"), + ("inner_proj", "intermediate_dense"), + ("self_attn.output_proj", "self_attn.linear_out"), + ("output_proj", "output_dense"), + ("self_attn.k_proj", "self_attn.linear_k"), + ("self_attn.v_proj", "self_attn.linear_v"), + ("self_attn.q_proj", "self_attn.linear_q"), + ("self_attn.sdpa.u_bias", "self_attn.pos_bias_u"), + ("self_attn.sdpa.v_bias", "self_attn.pos_bias_v"), + ("self_attn.sdpa.rel_k_embed", "self_attn.distance_embedding"), + ("self_attn.sdpa.r_proj", "self_attn.linear_pos"), + ("conv.pointwise_conv1", "conv_module.pointwise_conv1"), + ("conv.pointwise_conv2", "conv_module.pointwise_conv2"), + ("conv.depthwise_conv", "conv_module.depthwise_conv"), + ("conv.layer_norm", "conv_module.depthwise_layer_norm"), + ("conv_layer_norm", "conv_module.layer_norm"), + ("encoder.proj1", "intermediate_ffn.intermediate_dense"), + ("encoder.proj2", "intermediate_ffn.output_dense"), + ("encoder.layer_norm", "inner_layer_norm"), + ("masker.temporal_mask_embed", "masked_spec_embed"), +] + +keys_to_remove = { + "quantizer.entry_proj", + "final_proj", + "final_target_proj", + "quantizer.entries", + "quantizer.num_updates", +} + + +def param_count(model): + return sum(p[1].numel() for p in model.named_parameters() if "final_proj" not in p[0]) + + +def _convert_model( + original_model, + hf_model, + convert_list, +): + state_dict = original_model.state_dict() + + for k, v in list(state_dict.items()): + new_key = k + for old_layer_name, new_layer_name in convert_list: + if old_layer_name in new_key: + new_key = new_key.replace(old_layer_name, new_layer_name) + + # must do it by hand + if ".layer_norm" in new_key and new_key.split(".layer_norm")[0][-1].isnumeric(): + new_key = new_key.replace("layer_norm", "final_layer_norm") + + add_key = True + for key in keys_to_remove: + if key in new_key: + state_dict.pop(k) + add_key = False + break + + if add_key: + state_dict[new_key] = state_dict.pop(k) + + extra_keys = set(state_dict.keys()) - set(hf_model.state_dict().keys()) + extra_keys = set({k for k in extra_keys if "num_updates" not in k}) # filter 
unecessary param + missing_keys = set(hf_model.state_dict().keys()) - set(state_dict.keys()) + if len(extra_keys) != 0: + raise ValueError(f"extra keys found: {extra_keys}") + if len(missing_keys) != 0: + raise ValueError(f"missing keys: {missing_keys}") + hf_model.load_state_dict(state_dict, strict=True) + n_params = param_count(hf_model) + + logger.info(f"model loaded: {round(n_params / 1e6, 1)}M params") + + hf_model.eval() + del state_dict + + return hf_model + + +@torch.no_grad() +def convert_wav2vec2_bert_checkpoint( + checkpoint_path, + pytorch_dump_folder_path, + config_path=None, + repo_id=None, +): + """ + Copy/paste/tweak model's weights to transformers design. + """ + if config_path is not None: + config = Wav2Vec2BertConfig.from_pretrained(config_path, hidden_act="swish") + else: + config = Wav2Vec2BertConfig(apply_spec_augment=False) + + hf_wav2vec = Wav2Vec2BertModel(config) + + model = load_conformer_shaw_model(checkpoint_path, dtype=torch.float32) + model.eval() + + hf_wav2vec = _convert_model(model, hf_wav2vec, wav2vec_convert_list) + + hf_wav2vec.save_pretrained(pytorch_dump_folder_path) + + if repo_id: + hf_wav2vec.push_to_hub(repo_id, create_pr=True) + + # save feature extractor + fe = SeamlessM4TFeatureExtractor(padding_value=1) + fe._set_processor_class("Wav2Vec2BertProcessor") + fe.save_pretrained(pytorch_dump_folder_path) + + if repo_id: + fe.push_to_hub(repo_id, create_pr=True) + + if args.audio_path: + waveform, sample_rate = torchaudio.load(args.audio_path) + waveform = torchaudio.functional.resample(waveform, sample_rate, fe.sampling_rate) + + fbank_converter = WaveformToFbankConverter( + num_mel_bins=80, + waveform_scale=2**15, + channel_last=True, + standardize=True, + dtype=torch.float32, + ) + collater = Collater(pad_value=1) + + decoded_audio = {"waveform": waveform.T, "sample_rate": fe.sampling_rate, "format": -1} + src = collater(fbank_converter(decoded_audio))["fbank"] + seqs, padding_mask = get_seqs_and_padding_mask(src) + + with torch.inference_mode(): + seqs, padding_mask = model.encoder_frontend(seqs, padding_mask) + original_output, padding_mask = model.encoder(seqs, padding_mask) + + hf_wav2vec.eval() + + inputs = fe(waveform, return_tensors="pt", padding=True) + with torch.no_grad(): + outputs = hf_wav2vec(**inputs) + + torch.testing.assert_close(original_output, outputs.last_hidden_state, rtol=5e-3, atol=5e-3) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--pytorch_dump_folder_path", + default=None, + type=str, + help="Path to the output PyTorch model.", + ) + parser.add_argument( + "--checkpoint_path", default="conformer_shaw", type=str, help="Path to seamless communication checkpoint" + ) + parser.add_argument( + "--config_path", + default=None, + type=str, + help="Path to hf config.json of model to convert", + ) + parser.add_argument("--repo_id", default=None, type=str, help="Push to this repo id if precised.") + parser.add_argument( + "--audio_path", + default=None, + type=str, + help="If specified, check that the original model and the converted model produce the same outputs.", + ) + + args = parser.parse_args() + convert_wav2vec2_bert_checkpoint( + args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.repo_id + ) diff --git a/docs/transformers/build/lib/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py b/docs/transformers/build/lib/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py new file mode 100644 index 
0000000000000000000000000000000000000000..d3d393f9baa146a5fd6e17a0b30d09c65d742d0d --- /dev/null +++ b/docs/transformers/build/lib/transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py @@ -0,0 +1,1625 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_wav2vec2_bert.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_peft_available, +) +from .configuration_wav2vec2_bert import Wav2Vec2BertConfig + + +# General docstring +_CONFIG_FOR_DOC = "Wav2Vec2BertConfig" + + +class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module): + """Rotary positional embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf + """ + + def __init__(self, config): + super().__init__() + dim = config.hidden_size // config.num_attention_heads + base = config.rotary_embedding_base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)) + # Ignore copy + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + def forward(self, hidden_states): + sequence_length = hidden_states.shape[1] + + if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None: + return self.cached_rotary_positional_embedding + + self.cached_sequence_length = sequence_length + # Embeddings are computed in the dtype of the inv_freq constant + time_stamps = torch.arange(sequence_length).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq) + embeddings = torch.cat((freqs, freqs), dim=-1) + + cos_embeddings = embeddings.cos()[:, None, None, :] + sin_embeddings = embeddings.sin()[:, None, None, :] + # Computed embeddings are cast to the dtype of the hidden state inputs + self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).type_as(hidden_states) + return self.cached_rotary_positional_embedding + + +class Wav2Vec2BertRelPositionalEmbedding(nn.Module): + """Relative positional encoding module.""" + + def __init__(self, config): + super().__init__() + self.max_len = config.max_source_positions + self.d_model = config.hidden_size + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)) + + def extend_pe(self, x): + # Reset the positional encodings + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + 
+            if self.pe.size(1) >= x.size(1) * 2 - 1:
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        # Suppose `i` is the position of query vector and `j` is the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
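+        # Illustrative note, not part of the original file: for a query at
+        # position i and a key at position j, the relative distance i - j ranges
+        # over [-(L - 1), L - 1] for a length-L sequence, which is why self.pe
+        # caches 2 * L - 1 encodings, e.g. 9 position vectors for L = 5.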